Author | SHA1 | Message | Date |
---|---|---|---|
|
3c87768cc6 |
!590 去除根图节点上的dump origin name属性
Merge pull request !590 from 薛鹏/r1.9.0 |
2 years ago |
|
33c9ce1ebd |
!588 fix opensource problem
Merge pull request !588 from 蒋荣强/cherry-pick-1656571743 |
3 years ago |
|
a1b1ffab38 |
!587 超过2G的onnx模型导入
Merge pull request !587 from 黄桂军/c82_0630 |
3 years ago |
|
774e74a3f9 |
!581 修复dump子图节点时 origin name未拼接根图节点的问题
Merge pull request !581 from 薛鹏/r1.9.0 |
3 years ago |
|
8a3973ce99 |
!574 update owners
Merge pull request !574 from 王涛/r1.9.0 |
3 years ago |
|
c813469db2 |
!571 update .gitmodules.
Merge pull request !571 from 王涛/r1.9.0 |
3 years ago |
@@ -1,4 +1,4 @@ | |||||
[submodule "metadef"] | [submodule "metadef"] | ||||
path = metadef | path = metadef | ||||
url = https://gitee.com/ascend/metadef.git | url = https://gitee.com/ascend/metadef.git | ||||
branch = master | |||||
branch = r1.9.0 |
@@ -1,13 +1,12 @@ | |||||
approvers: | approvers: | ||||
- ji_chen | |||||
- wqtshg | |||||
- ljl0711 | |||||
- liu-jisheng | |||||
- startzgf168 | |||||
- andylhy | |||||
- liyihan123 | |||||
- zhangfan_hq | - zhangfan_hq | ||||
- lipeiyang3699 | |||||
reviewers: | reviewers: | ||||
- xchu42 | - xchu42 | ||||
- sheng-nan | - sheng-nan | ||||
- tangqunzhang | - tangqunzhang | ||||
- wangxiaotian22 | - wangxiaotian22 | ||||
- stevenaw | |||||
- stevenaw | |||||
- xuepenginnanjing |
@@ -1 +1 @@ | |||||
Subproject commit 4f61fa7a7181e0e7dbdd4acbfaf99088a58d920d | |||||
Subproject commit 35de9facd31448995922246c5d2ffaa5a726bbb1 |
@@ -74,6 +74,7 @@ using std::ifstream; | |||||
namespace { | namespace { | ||||
const size_t kMaxErrStrLen = 128U; | const size_t kMaxErrStrLen = 128U; | ||||
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
@@ -282,7 +283,7 @@ Status CheckPathValid(const char *model_path, const string &custom_proto, string | |||||
const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxWithLoss", "Accuracy", "Data", | const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxWithLoss", "Accuracy", "Data", | ||||
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; | "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; | ||||
Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) { | |||||
Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { | |||||
if (proto_message.input_size() > 0) { | if (proto_message.input_size() > 0) { | ||||
GELOGI("This net exsit input."); | GELOGI("This net exsit input."); | ||||
@@ -456,7 +457,7 @@ Status CaffeModelParser::CustomProtoParse(const char *model_path, const string & | |||||
return ret; | return ret; | ||||
} | } | ||||
Status CaffeModelParser::ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) { | |||||
Status CaffeModelParser::ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) const { | |||||
int32_t copy_fd = mmDup(STDERR_FILENO); | int32_t copy_fd = mmDup(STDERR_FILENO); | ||||
if (copy_fd < 0) { | if (copy_fd < 0) { | ||||
char_t err_buf[kMaxErrStrLen + 1U] = {}; | char_t err_buf[kMaxErrStrLen + 1U] = {}; | ||||
@@ -536,7 +537,7 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: | |||||
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | ||||
const google::protobuf::Message *message, | const google::protobuf::Message *message, | ||||
vector<ge::Operator> &operators) { | |||||
vector<ge::Operator> &operators) const { | |||||
auto field_name = layer_descriptor->FindFieldByName(kFieldName); | auto field_name = layer_descriptor->FindFieldByName(kFieldName); | ||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); | ||||
auto field_type = layer_descriptor->FindFieldByName(kFieldType); | auto field_type = layer_descriptor->FindFieldByName(kFieldType); | ||||
@@ -624,7 +625,7 @@ void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_ind | |||||
ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); | ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); | ||||
} | } | ||||
Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { | |||||
Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) const { | |||||
if (ge::GetParserContext().user_out_tensors.empty()) { | if (ge::GetParserContext().user_out_tensors.empty()) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -932,7 +933,7 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom | |||||
} | } | ||||
Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | ||||
const string &op_type) { | |||||
const string &op_type) const { | |||||
if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) { | if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) { | ||||
op_desc = ge::parser::MakeShared<ge::OpDesc>(layer.name(), op_type); | op_desc = ge::parser::MakeShared<ge::OpDesc>(layer.name(), op_type); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
@@ -1202,7 +1203,7 @@ std::string CaffeModelParser::RemapTopNameByLayer(const domi::caffe::LayerParame | |||||
return (top_name + "_" + layer.name() + "_" + std::to_string(index)); | return (top_name + "_" + layer.name() + "_" + std::to_string(index)); | ||||
} | } | ||||
Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) { | |||||
Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) const { | |||||
// Add layer in the model to PreChecker and check the general parameters | // Add layer in the model to PreChecker and check the general parameters | ||||
PreChecker::Instance().SetModelName(net.name()); | PreChecker::Instance().SetModelName(net.name()); | ||||
for (int i = 0; i < net.layer_size(); i++) { | for (int i = 0; i < net.layer_size(); i++) { | ||||
@@ -1977,7 +1978,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r | |||||
} | } | ||||
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, | Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, | ||||
google::protobuf::Message *blobs) { | |||||
google::protobuf::Message *blobs) const { | |||||
const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); | const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); | ||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
@@ -52,12 +52,11 @@ using std::string; | |||||
using std::unordered_map; | using std::unordered_map; | ||||
using std::vector; | using std::vector; | ||||
using domi::Status; | using domi::Status; | ||||
static std::map<std::vector<std::string>, std::vector<std::string>> params_share_map; | |||||
class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | ||||
public: | public: | ||||
CaffeModelParser() {} | CaffeModelParser() {} | ||||
virtual ~CaffeModelParser() override {} | |||||
~CaffeModelParser() override {} | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -145,7 +144,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return SUCCESS build successfully | * @return SUCCESS build successfully | ||||
* @return FAILED build failed | * @return FAILED build failed | ||||
*/ | */ | ||||
Status PreCheck(const domi::caffe::NetParameter &net); | |||||
Status PreCheck(const domi::caffe::NetParameter &net) const; | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -156,7 +155,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return SUCCESS build successfully | * @return SUCCESS build successfully | ||||
* @return FAILED build failed | * @return FAILED build failed | ||||
*/ | */ | ||||
Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag); | |||||
Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -192,7 +191,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return SUCCESS read file successfully | * @return SUCCESS read file successfully | ||||
* @return FAILED read file failed | * @return FAILED read file failed | ||||
*/ | */ | ||||
Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message); | |||||
Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) const; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -214,7 +213,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return FAILED parse layer failed | * @return FAILED parse layer failed | ||||
*/ | */ | ||||
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | ||||
const google::protobuf::Message *message, std::vector<ge::Operator> &operators); | |||||
const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -301,7 +300,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) const; | Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) const; | ||||
Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | ||||
const string &op_type); | |||||
const string &op_type) const; | |||||
Status AddUserOutNodesTop(); | Status AddUserOutNodesTop(); | ||||
@@ -321,7 +320,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
void AddOutputInfoToContext(string layer_name, int32_t top_index) const; | void AddOutputInfoToContext(string layer_name, int32_t top_index) const; | ||||
Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message); | |||||
Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) const; | |||||
Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); | Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); | ||||
@@ -405,7 +404,7 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||||
google::protobuf::Message *layer); | google::protobuf::Message *layer); | ||||
Status ConvertBlobsProto(const google::protobuf::Message *message, | Status ConvertBlobsProto(const google::protobuf::Message *message, | ||||
google::protobuf::Message *blobs); | |||||
google::protobuf::Message *blobs) const; | |||||
Status ConvertBlobShapeProto(const google::protobuf::Message *message, | Status ConvertBlobShapeProto(const google::protobuf::Message *message, | ||||
google::protobuf::Message *dest_message) const; | google::protobuf::Message *dest_message) const; | ||||
@@ -266,7 +266,7 @@ void AclGrphParseUtil::SetDefaultFormat() { | |||||
} | } | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { | |||||
domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) const { | |||||
try { | try { | ||||
ge::GetParserContext().out_nodes_map.clear(); | ge::GetParserContext().out_nodes_map.clear(); | ||||
ge::GetParserContext().user_out_nodes.clear(); | ge::GetParserContext().user_out_nodes.clear(); | ||||
@@ -492,7 +492,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||||
} | } | ||||
domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | ||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const { | |||||
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | ||||
if (!default_out_nodes.empty()) { | if (!default_out_nodes.empty()) { | ||||
for (size_t i = 0; i < default_out_nodes.size(); ++i) { | for (size_t i = 0; i < default_out_nodes.size(); ++i) { | ||||
@@ -587,7 +587,7 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS | |||||
} | } | ||||
string key_str = key_ascend; | string key_str = key_ascend; | ||||
auto it = ge::ir_option::ir_parser_suppported_options.find(key_str); | |||||
std::set<std::string>::const_iterator it = ge::ir_option::ir_parser_suppported_options.find(key_str); | |||||
if (it == ge::ir_option::ir_parser_suppported_options.end()) { | if (it == ge::ir_option::ir_parser_suppported_options.end()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"parser_params", key_str}); | ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"parser_params", key_str}); | ||||
GELOGE(PARAM_INVALID, "[Check][Param] Input options include unsupported option(%s).Please check!", key_ascend); | GELOGE(PARAM_INVALID, "[Check][Param] Input options include unsupported option(%s).Please check!", key_ascend); | ||||
@@ -651,7 +651,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | ||||
const std::map<AscendString, AscendString> &parser_params) { | |||||
const std::map<AscendString, AscendString> &parser_params) const { | |||||
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, | // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, | ||||
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
@@ -943,7 +943,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std | |||||
regex_t reg; | regex_t reg; | ||||
int cflags = REG_EXTENDED | REG_NOSUB; | int cflags = REG_EXTENDED | REG_NOSUB; | ||||
int ret = regcomp(®, mode.c_str(), cflags); | int ret = regcomp(®, mode.c_str(), cflags); | ||||
if (ret) { | |||||
if (ret != 0) { | |||||
regerror(ret, ®, ebuff, kMaxBuffSize); | regerror(ret, ®, ebuff, kMaxBuffSize); | ||||
GELOGW("regcomp failed, reason: %s", ebuff); | GELOGW("regcomp failed, reason: %s", ebuff); | ||||
regfree(®); | regfree(®); | ||||
@@ -951,7 +951,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std | |||||
} | } | ||||
ret = regexec(®, filePath.c_str(), 0, nullptr, 0); | ret = regexec(®, filePath.c_str(), 0, nullptr, 0); | ||||
if (ret) { | |||||
if (ret != 0) { | |||||
regerror(ret, ®, ebuff, kMaxBuffSize); | regerror(ret, ®, ebuff, kMaxBuffSize); | ||||
GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff); | GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff); | ||||
regfree(®); | regfree(®); | ||||
@@ -44,7 +44,8 @@ class AclGrphParseUtil { | |||||
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | ||||
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | ||||
std::string &graph_name); | std::string &graph_name); | ||||
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||||
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, | |||||
AscendString> &parser_params) const; | |||||
private: | private: | ||||
bool parser_initialized = false; | bool parser_initialized = false; | ||||
@@ -53,7 +54,7 @@ class AclGrphParseUtil { | |||||
void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
std::vector<std::string> &output_nodes_name) const; | std::vector<std::string> &output_nodes_name) const; | ||||
static void SetDefaultFormat(); | static void SetDefaultFormat(); | ||||
domi::Status ParseAclOutputNodes(const std::string &out_nodes); | |||||
domi::Status ParseAclOutputNodes(const std::string &out_nodes) const; | |||||
domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16) const; | domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16) const; | ||||
domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes) const; | domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes) const; | ||||
static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name, | static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name, | ||||
@@ -61,7 +62,7 @@ class AclGrphParseUtil { | |||||
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | ||||
const string &is_input_adjust_hw_layout) const; | const string &is_input_adjust_hw_layout) const; | ||||
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | ||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const; | |||||
}; | }; | ||||
namespace parser { | namespace parser { | ||||
@@ -77,7 +77,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi | |||||
const char *model_char = model_str.c_str(); | const char *model_char = model_str.c_str(); | ||||
uint32_t len = static_cast<uint32_t>(model_str.length()); | uint32_t len = static_cast<uint32_t>(model_str.length()); | ||||
// Write data to file | // Write data to file | ||||
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); | |||||
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>(static_cast<const void *>(model_char)), len); | |||||
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | ||||
char_t err_buf[kMaxErrStrLen + 1U] = {}; | char_t err_buf[kMaxErrStrLen + 1U] = {}; | ||||
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen); | ||||
@@ -48,7 +48,7 @@ static bool IsRoundOne(uint64_t man, uint16_t trunc_len) { | |||||
uint64_t mask0 = 0x4; | uint64_t mask0 = 0x4; | ||||
uint64_t mask1 = 0x2; | uint64_t mask1 = 0x2; | ||||
uint64_t mask2; | uint64_t mask2; | ||||
uint16_t shift_out = static_cast<uint16_t>(trunc_len - kDim2); | |||||
uint16_t shift_out = static_cast<uint16_t>(trunc_len - static_cast<uint16_t>(kDim2)); | |||||
mask0 = mask0 << shift_out; | mask0 = mask0 << shift_out; | ||||
mask1 = mask1 << shift_out; | mask1 = mask1 << shift_out; | ||||
mask2 = mask1 - 1; | mask2 = mask1 - 1; | ||||
@@ -89,7 +89,7 @@ static float Fp16ToFloat(const uint16_t &fp_val) { | |||||
int16_t hf_exp; | int16_t hf_exp; | ||||
ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); | ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); | ||||
while (hf_man && !(hf_man & kFp16ManHideBit)) { | |||||
while ((hf_man != 0U) && ((hf_man & kFp16ManHideBit) == 0U)) { | |||||
hf_man <<= 1; | hf_man <<= 1; | ||||
hf_exp--; | hf_exp--; | ||||
} | } | ||||
@@ -120,7 +120,7 @@ static double Fp16ToDouble(const uint16_t &fp_val) { | |||||
int16_t hf_exp; | int16_t hf_exp; | ||||
ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); | ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); | ||||
while (hf_man && !(hf_man & kFp16ManHideBit)) { | |||||
while ((hf_man != 0U) && ((hf_man & kFp16ManHideBit) == 0U)) { | |||||
hf_man <<= 1; | hf_man <<= 1; | ||||
hf_exp--; | hf_exp--; | ||||
} | } | ||||
@@ -128,7 +128,7 @@ static double Fp16ToDouble(const uint16_t &fp_val) { | |||||
uint64_t e_ret; | uint64_t e_ret; | ||||
uint64_t m_ret; | uint64_t m_ret; | ||||
uint64_t s_ret = hf_sign; | uint64_t s_ret = hf_sign; | ||||
if (!hf_man) { | |||||
if (hf_man == 0U) { | |||||
e_ret = 0; | e_ret = 0; | ||||
m_ret = 0; | m_ret = 0; | ||||
} else { | } else { | ||||
@@ -256,7 +256,7 @@ static uint8_t Fp16ToUInt8(const uint16_t &fp_val) { | |||||
shift_out++; | shift_out++; | ||||
} | } | ||||
} | } | ||||
if (!overflow_flag) { | |||||
if (overflow_flag == 0U) { | |||||
bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); | bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); | ||||
m_ret = static_cast<uint8_t>((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max); | m_ret = static_cast<uint8_t>((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max); | ||||
if (need_round && m_ret != kBitLen8Max) { | if (need_round && m_ret != kBitLen8Max) { | ||||
@@ -290,7 +290,7 @@ static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, co | |||||
if (m_ret == 0) { | if (m_ret == 0) { | ||||
s_ret = 0; | s_ret = 0; | ||||
} | } | ||||
return static_cast<uint16_t>((s_ret << kBitShift15) | (m_ret)); | |||||
return static_cast<uint16_t>((s_ret << static_cast<uint16_t>(kBitShift15)) | (m_ret)); | |||||
} | } | ||||
/// @ingroup fp16_t math conversion static method | /// @ingroup fp16_t math conversion static method | ||||
@@ -431,7 +431,7 @@ static int32_t Fp16ToInt32(const uint16_t &fp_val) { | |||||
s_ret = 0; | s_ret = 0; | ||||
} | } | ||||
// Generate final result | // Generate final result | ||||
ret_v = (s_ret << kBitShift31) | (m_ret); | |||||
ret_v = (s_ret << static_cast<uint16_t>(kBitShift31)) | (m_ret); | |||||
} | } | ||||
return *(ge::PtrToPtr<uint32_t, uint32_t>(&ret_v)); | return *(ge::PtrToPtr<uint32_t, uint32_t>(&ret_v)); | ||||
@@ -565,7 +565,7 @@ static uint16_t Fp16Add(uint16_t v_1, uint16_t v_2) { | |||||
m_trunc = (m_b << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp))); | m_trunc = (m_b << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp))); | ||||
m_b = RightShift(m_b, e_tmp); | m_b = RightShift(m_b, e_tmp); | ||||
} else if (e_a < e_b) { | } else if (e_a < e_b) { | ||||
m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp))); | |||||
m_trunc = (m_a << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp))); | |||||
m_a = RightShift(m_a, e_tmp); | m_a = RightShift(m_a, e_tmp); | ||||
} | } | ||||
// calculate mantissav | // calculate mantissav | ||||
@@ -603,7 +603,7 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||||
m_a = m_a_tmp; | m_a = m_a_tmp; | ||||
m_b = m_b_tmp; | m_b = m_b_tmp; | ||||
e_ret = ((e_a + e_b) - kFp16ExpBias) - kDim10; | |||||
e_ret = ((e_a + e_b) - kFp16ExpBias) - static_cast<int16_t>(kDim10); | |||||
mul_m = m_a * m_b; | mul_m = m_a * m_b; | ||||
s_ret = s_a ^ s_b; | s_ret = s_a ^ s_b; | ||||
@@ -905,7 +905,7 @@ fp16_t &fp16_t::operator=(const float &f_val) { | |||||
fp16_t &fp16_t::operator=(const int8_t &i_val) { | fp16_t &fp16_t::operator=(const int8_t &i_val) { | ||||
uint16_t s_ret, e_ret, m_ret; | uint16_t s_ret, e_ret, m_ret; | ||||
s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> kDim7); | |||||
s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> static_cast<uint8_t>(kDim7)); | |||||
m_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & kInt8Max)); | m_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & kInt8Max)); | ||||
if (m_ret == 0) { | if (m_ret == 0) { | ||||
@@ -952,14 +952,14 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||||
uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp)); | uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp)); | ||||
if (static_cast<bool>(m_tmp)) { | if (static_cast<bool>(m_tmp)) { | ||||
int16_t e_ret; | int16_t e_ret; | ||||
if (len > kDim11) { | |||||
if (len > static_cast<uint16_t>(kDim11)) { | |||||
e_ret = kFp16ExpBias + kFp16ManLen; | e_ret = kFp16ExpBias + kFp16ManLen; | ||||
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | uint16_t e_tmp = len - static_cast<uint16_t>(kDim11); | ||||
uint32_t trunc_mask = 1; | uint32_t trunc_mask = 1; | ||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
} | } | ||||
uint32_t m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp); | |||||
uint32_t m_trunc = (m_tmp & trunc_mask) << (static_cast<uint16_t>(kBitShift32) - e_tmp); | |||||
for (int i = 0; i < e_tmp; i++) { | for (int i = 0; i < e_tmp; i++) { | ||||
m_tmp = (m_tmp >> 1); | m_tmp = (m_tmp >> 1); | ||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
@@ -991,7 +991,7 @@ fp16_t &fp16_t::operator=(const int16_t &i_val) { | |||||
val = 0; | val = 0; | ||||
} else { | } else { | ||||
uint16_t ui_val = *(ge::PtrToPtr<const int16_t, const int16_t>(&i_val)); | uint16_t ui_val = *(ge::PtrToPtr<const int16_t, const int16_t>(&i_val)); | ||||
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15); | |||||
auto s_ret = static_cast<uint16_t>(ui_val >> static_cast<uint16_t>(kBitShift15)); | |||||
if (static_cast<bool>(s_ret)) { | if (static_cast<bool>(s_ret)) { | ||||
int16_t iValM = -i_val; | int16_t iValM = -i_val; | ||||
ui_val = *(ge::PtrToPtr<int16_t, uint16_t>(&iValM)); | ui_val = *(ge::PtrToPtr<int16_t, uint16_t>(&iValM)); | ||||
@@ -1018,7 +1018,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
} | } | ||||
m_trunc = (m_ret & trunc_mask) << (kBitShift32 - e_tmp); | |||||
m_trunc = (m_ret & trunc_mask) << (static_cast<uint16_t>(kBitShift32) - e_tmp); | |||||
for (int i = 0; i < e_tmp; i++) { | for (int i = 0; i < e_tmp; i++) { | ||||
m_ret = (m_ret >> 1); | m_ret = (m_ret >> 1); | ||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
@@ -1040,7 +1040,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||||
} | } | ||||
} else { | } else { | ||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | e_ret = static_cast<int16_t>(kFp16ExpBias); | ||||
m_ret = m_ret << (kDim11 - len); | |||||
m_ret = m_ret << (static_cast<uint16_t>(kDim11) - len); | |||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
val = FP16_CONSTRUCTOR(0u, static_cast<uint16_t>(e_ret), m_ret); | val = FP16_CONSTRUCTOR(0u, static_cast<uint16_t>(e_ret), m_ret); | ||||
@@ -1062,7 +1062,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
} | } | ||||
m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp); | |||||
m_trunc = (m_tmp & trunc_mask) << (static_cast<uint16_t>(kBitShift32) - e_tmp); | |||||
for (int i = 0; i < e_tmp; i++) { | for (int i = 0; i < e_tmp; i++) { | ||||
m_tmp = (m_tmp >> 1); | m_tmp = (m_tmp >> 1); | ||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
@@ -1085,7 +1085,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||||
} | } | ||||
} else { | } else { | ||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | e_ret = static_cast<int16_t>(kFp16ExpBias); | ||||
m_tmp = m_tmp << (kDim11 - len); | |||||
m_tmp = m_tmp << (static_cast<uint16_t>(kDim11) - len); | |||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
auto m_ret = static_cast<uint16_t>(m_tmp); | auto m_ret = static_cast<uint16_t>(m_tmp); | ||||
@@ -1097,7 +1097,7 @@ fp16_t &fp16_t::operator=(const int32_t &i_val) { | |||||
val = 0; | val = 0; | ||||
} else { | } else { | ||||
uint32_t ui_val = *(ge::PtrToPtr<const int32_t, const uint32_t>(&i_val)); | uint32_t ui_val = *(ge::PtrToPtr<const int32_t, const uint32_t>(&i_val)); | ||||
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31); | |||||
auto s_ret = static_cast<uint16_t>(ui_val >> static_cast<uint16_t>(kBitShift31)); | |||||
if (static_cast<bool>(s_ret)) { | if (static_cast<bool>(s_ret)) { | ||||
int32_t iValM = -i_val; | int32_t iValM = -i_val; | ||||
ui_val = *(ge::PtrToPtr<int32_t, uint32_t>(&iValM)); | ui_val = *(ge::PtrToPtr<int32_t, uint32_t>(&iValM)); | ||||
@@ -1124,7 +1124,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||||
for (int i = 1; i < e_tmp; i++) { | for (int i = 1; i < e_tmp; i++) { | ||||
trunc_mask = (trunc_mask << 1) + 1; | trunc_mask = (trunc_mask << 1) + 1; | ||||
} | } | ||||
m_trunc = (m_tmp & trunc_mask) << static_cast<uint32_t>(kBitShift32 - e_tmp); | |||||
m_trunc = (m_tmp & trunc_mask) << static_cast<uint32_t>(static_cast<uint16_t>(kBitShift32) - e_tmp); | |||||
for (uint16_t i = 0; i < e_tmp; i++) { | for (uint16_t i = 0; i < e_tmp; i++) { | ||||
m_tmp = (m_tmp >> 1); | m_tmp = (m_tmp >> 1); | ||||
e_ret = e_ret + 1; | e_ret = e_ret + 1; | ||||
@@ -1147,7 +1147,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||||
} | } | ||||
} else { | } else { | ||||
e_ret = static_cast<int16_t>(kFp16ExpBias); | e_ret = static_cast<int16_t>(kFp16ExpBias); | ||||
m_tmp = m_tmp << (kDim11 - len); | |||||
m_tmp = m_tmp << (static_cast<uint16_t>(kDim11) - len); | |||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
auto m_ret = static_cast<uint16_t>(m_tmp); | auto m_ret = static_cast<uint16_t>(m_tmp); | ||||
@@ -131,6 +131,7 @@ const char *YOLO2REORG = "Yolo2Reorg"; | |||||
const char *REDUCESUM = "ReduceSum"; | const char *REDUCESUM = "ReduceSum"; | ||||
const char *SUM = "Sum"; | const char *SUM = "Sum"; | ||||
const char *CONSTANT = "Const"; | const char *CONSTANT = "Const"; | ||||
const char *FILECONSTANT = "FileConstant"; | |||||
const char *RESIZEBILINEAR = "ResizeBilinear"; | const char *RESIZEBILINEAR = "ResizeBilinear"; | ||||
const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; | const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; | ||||
const char *MAXIMUM = "Maximum"; | const char *MAXIMUM = "Maximum"; | ||||
@@ -19,8 +19,6 @@ | |||||
#include <memory> | #include <memory> | ||||
#include "common/fmk_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
/// | /// | ||||
/// @ingroup domi_omg | /// @ingroup domi_omg | ||||
@@ -218,9 +218,9 @@ Status PreChecker::Save(const string &file) { | |||||
// Constructing JSON information of operators in order of network | // Constructing JSON information of operators in order of network | ||||
for (auto id : ops_) { | for (auto id : ops_) { | ||||
auto iter = op_map_.find(id); | |||||
GE_CHK_BOOL_RET_STATUS(iter != op_map_.end(), FAILED, "[Check][Param] don't find this op."); | |||||
Info &info = iter->second; | |||||
std::map<OpId, Info>::const_iterator iter = op_map_.find(id); | |||||
GE_CHK_BOOL_RET_STATUS(iter != op_map_.cend(), FAILED, "[Check][Param] don't find this op."); | |||||
const Info &info = iter->second; | |||||
// Initialization operator general information | // Initialization operator general information | ||||
nlohmann::json op = {{kKeyOpName, info.name}, {kKeyOpType, info.type}}; | nlohmann::json op = {{kKeyOpName, info.name}, {kKeyOpType, info.type}}; | ||||
@@ -67,7 +67,7 @@ bool GetIdentifier(const std::string &line, int &identifier) { | |||||
break; | break; | ||||
} | } | ||||
if (line[i] >= kMinNum && line[i] <= kMaxNum) { | if (line[i] >= kMinNum && line[i] <= kMaxNum) { | ||||
identifier = identifier * kDecimalMulti + line[i] - kMinNum; | |||||
identifier = identifier * kDecimalMulti + static_cast<int>(line[i]) - static_cast<int>(kMinNum); | |||||
} | } | ||||
if (identifier > kMaxIdentifier || identifier < 0) { | if (identifier > kMaxIdentifier || identifier < 0) { | ||||
return false; | return false; | ||||
@@ -75,7 +75,7 @@ bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_tra | |||||
return ret; | return ret; | ||||
} | } | ||||
bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) const { | |||||
if (reg_data.GetFrameworkType() == domi::TENSORFLOW) { | if (reg_data.GetFrameworkType() == domi::TENSORFLOW) { | ||||
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | ||||
if (factory == nullptr) { | if (factory == nullptr) { | ||||
@@ -27,7 +27,7 @@ class OpRegistrationTbe { | |||||
bool Finalize(const OpRegistrationData ®_data, bool is_train = false); | bool Finalize(const OpRegistrationData ®_data, bool is_train = false); | ||||
private: | private: | ||||
bool RegisterParser(const OpRegistrationData ®_data); | |||||
bool RegisterParser(const OpRegistrationData ®_data) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -4,6 +4,7 @@ set(SRC_LIST | |||||
"onnx_data_parser.cc" | "onnx_data_parser.cc" | ||||
"onnx_util.cc" | "onnx_util.cc" | ||||
"onnx_constant_parser.cc" | "onnx_constant_parser.cc" | ||||
"onnx_file_constant_parser.cc" | |||||
"subgraph_adapter/if_subgraph_adapter.cc" | "subgraph_adapter/if_subgraph_adapter.cc" | ||||
"subgraph_adapter/subgraph_adapter_factory.cc" | "subgraph_adapter/subgraph_adapter_factory.cc" | ||||
) | ) | ||||
@@ -17,6 +17,7 @@ PARSER_ONNX_SRC_FILES := \ | |||||
onnx_data_parser.cc \ | onnx_data_parser.cc \ | ||||
onnx_util.cc \ | onnx_util.cc \ | ||||
onnx_constant_parser.cc \ | onnx_constant_parser.cc \ | ||||
onnx_file_constant_parser.cc \ | |||||
proto/onnx/ge_onnx.proto \ | proto/onnx/ge_onnx.proto \ | ||||
proto/om.proto \ | proto/om.proto \ | ||||
@@ -0,0 +1,150 @@ | |||||
/** | |||||
* Copyright 2022 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "onnx_file_constant_parser.h" | |||||
#include <vector> | |||||
#include "graph/ge_tensor.h" | |||||
#include "parser/common/op_parser_factory.h" | |||||
#include "parser/onnx/onnx_util.h" | |||||
#include "framework/common/util.h" | |||||
#include "framework/common/types.h" | |||||
using ge::onnx::NodeProto; | |||||
using ge::onnx::TensorProto; | |||||
using domi::ONNX; | |||||
using GeShape = ge::GeShape; | |||||
using GeTensorDesc = ge::GeTensorDesc; | |||||
using namespace ge::parser; | |||||
namespace { | |||||
const std::string kAttrShape = "shape"; | |||||
const std::string kAttrDataType = "dtype"; | |||||
const std::string kFileConstantPath = "file_constant_path"; | |||||
const std::string kLocation = "location"; | |||||
const std::string kOffset = "offset"; | |||||
const int64_t kOffsetCoefficient = 4096; | |||||
const char *const kFileConstant = "FileConstant"; | |||||
} | |||||
namespace ge { | |||||
Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | |||||
GE_CHECK_NOTNULL(op_src); | |||||
const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | |||||
ge::onnx::TensorProto tensor_proto; | |||||
if (GetTensorProto(node, tensor_proto) != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | |||||
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (ParseDataType(tensor_proto, op_def) != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "node[%s] parse data type failed", node->name().c_str()); | |||||
GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse data type failed", node->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
if (ParsePath(tensor_proto, op_def) != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "node[%s] parse file path failed", node->name().c_str()); | |||||
GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse file path failed", node->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
ParseShape(tensor_proto, op_def); | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto, | |||||
ge::onnx::TensorProto &tensor_proto) { | |||||
for (const auto &it : node_proto->attribute()) { | |||||
if (it.name() != ge::kAttrNameValue) { | |||||
continue; | |||||
} | |||||
tensor_proto = it.t(); | |||||
return SUCCESS; | |||||
} | |||||
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str()); | |||||
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
std::vector<int64_t> tmp_shape; | |||||
for (int i = 0; i < tensor_proto.dims_size(); i++) { | |||||
tmp_shape.push_back(tensor_proto.dims(i)); | |||||
} | |||||
op_def.SetAttr(kAttrShape.c_str(), tmp_shape); | |||||
} | |||||
Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
int64_t data_type = tensor_proto.data_type(); | |||||
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | |||||
if (type >= ge::DataType::DT_UNDEFINED) { | |||||
REPORT_INNER_ERROR("E19999", "tensor_proto date type %ld is undefined.", data_type); | |||||
GELOGE(domi::PARAM_INVALID, "[Check][Param] tensor_proto date type %ld is undefined.", data_type); | |||||
return FAILED; | |||||
} | |||||
op_def.SetAttr(kAttrDataType.c_str(), type); | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) { | |||||
ge::NamedAttrs attrs; | |||||
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | |||||
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | |||||
if (SetPathAttr(string_proto, attrs) != SUCCESS) { | |||||
REPORT_INNER_ERROR("E19999", "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str()); | |||||
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
if (!attrs.HasAttr(kLocation)) { | |||||
REPORT_INNER_ERROR("E19999", "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
op_def.SetAttr(kFileConstantPath.c_str(), attrs); | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | |||||
ge::NamedAttrs &attrs) { | |||||
if (string_proto.key() == kLocation) { | |||||
AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | |||||
} else { | |||||
int64_t value; | |||||
try { | |||||
value = stol(string_proto.value()); | |||||
} catch (const std::exception &e) { | |||||
REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
return FAILED; | |||||
} | |||||
if (string_proto.key() == kOffset) { | |||||
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) { | |||||
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||||
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||||
return FAILED; | |||||
} | |||||
value *= kOffsetCoefficient; | |||||
} | |||||
AttrUtils::SetInt(attrs, string_proto.key(), value); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser); | |||||
} // namespace ge |
@@ -0,0 +1,37 @@ | |||||
/** | |||||
* Copyright 2022 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
#define GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
#include "parser/onnx/onnx_op_parser.h" | |||||
#include "proto/onnx/ge_onnx.pb.h" | |||||
namespace ge { | |||||
class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { | |||||
public: | |||||
Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | |||||
private: | |||||
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def); | |||||
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto); | |||||
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ |
@@ -44,6 +44,12 @@ | |||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "subgraph_adapter/subgraph_adapter_factory.h" | #include "subgraph_adapter/subgraph_adapter_factory.h" | ||||
#include "framework/common/types.h" | |||||
#include "mmpa/mmpa_api.h" | |||||
namespace { | |||||
const std::string kLocation = "location"; | |||||
} | |||||
namespace ge { | namespace ge { | ||||
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | ||||
@@ -160,7 +166,8 @@ namespace ge { | |||||
namespace { | namespace { | ||||
const std::map<std::string, std::string> kOnnxOpMap = { | const std::map<std::string, std::string> kOnnxOpMap = { | ||||
{ge::kOpTypeInput, ge::parser::DATA}, | {ge::kOpTypeInput, ge::parser::DATA}, | ||||
{ge::kOpTypeConstant, ge::parser::CONSTANT} | |||||
{ge::kOpTypeConstant, ge::parser::CONSTANT}, | |||||
{ge::kFileConstant, ge::parser::FILECONSTANT} | |||||
}; | }; | ||||
const int64_t kDimValue = 1; | const int64_t kDimValue = 1; | ||||
@@ -350,12 +357,16 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
ge::onnx::NodeProto *const_node = onnx_graph.add_node(); | ge::onnx::NodeProto *const_node = onnx_graph.add_node(); | ||||
std::string output_name = it.first + "_" + to_string(index++); | std::string output_name = it.first + "_" + to_string(index++); | ||||
const_node->set_name(output_name); | const_node->set_name(output_name); | ||||
const_node->set_op_type(ge::kOpTypeConstant); | |||||
const_node->add_output(it.first); | const_node->add_output(it.first); | ||||
ge::onnx::AttributeProto *attribute = const_node->add_attribute(); | ge::onnx::AttributeProto *attribute = const_node->add_attribute(); | ||||
attribute->set_name(ge::kAttrNameValue); | attribute->set_name(ge::kAttrNameValue); | ||||
ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); | ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); | ||||
*attribute_t = it.second; | *attribute_t = it.second; | ||||
if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
const_node->set_op_type(kFileConstant); | |||||
} else { | |||||
const_node->set_op_type(ge::kOpTypeConstant); | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -723,6 +734,51 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto | |||||
GELOGE(PARAM_INVALID, "[Read][ModeFile] failed."); | GELOGE(PARAM_INVALID, "[Read][ModeFile] failed."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (SetExternalPath(file, onnx_model) != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Set external path failed, file[%s]", file); | |||||
GELOGE(PARAM_INVALID, "[Set][ExternalPath] failed."); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxModelParser::SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const { | |||||
std::string real_path = ge::parser::RealPath(file); | |||||
const size_t file_len = real_path.length(); | |||||
std::unique_ptr<char[]> tmp_file(new (std::nothrow) char[file_len + 1U]); | |||||
GE_CHECK_NOTNULL(tmp_file); | |||||
const auto ret = strncpy_s(tmp_file.get(), file_len + 1U, real_path.c_str(), file_len); | |||||
if (ret != EN_OK) { | |||||
REPORT_CALL_ERROR("E19999", "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu, ret=%d.", | |||||
real_path.c_str(), tmp_file.get(), file_len, file_len + 1U, ret); | |||||
GELOGE(FAILED, "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu.", | |||||
real_path.c_str(), tmp_file.get(), file_len, file_len + 1U); | |||||
return FAILED; | |||||
} | |||||
const char *const dir = mmDirName(tmp_file.get()); | |||||
GE_CHECK_NOTNULL(dir); | |||||
const ge::onnx::GraphProto &onnx_graph = onnx_model.graph(); | |||||
for (int32_t i = 0; i < onnx_graph.initializer_size(); ++i) { | |||||
const ge::onnx::TensorProto &initializer_tensor = onnx_graph.initializer(i); | |||||
if (initializer_tensor.data_location() != ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
continue; | |||||
} | |||||
for (int32_t j = 0; j < initializer_tensor.external_data_size(); ++j) { | |||||
ge::onnx::StringStringEntryProto &string_proto = | |||||
const_cast<ge::onnx::StringStringEntryProto &>(initializer_tensor.external_data(j)); | |||||
if (string_proto.key() != kLocation) { | |||||
continue; | |||||
} | |||||
const std::string &file_name = string_proto.value(); | |||||
const std::string new_file = std::string(dir) + MMPA_PATH_SEPARATOR_STR + file_name; | |||||
GELOGD("[%s] is external data. concat dir[%s] and file_name[%s], new_file[%s]", | |||||
initializer_tensor.name().c_str(), dir, file_name.c_str(), new_file.c_str()); | |||||
string_proto.set_value(new_file); | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -126,6 +126,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const; | Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const; | ||||
Status SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const; | |||||
Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const; | ||||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); | Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph); | ||||
@@ -48,6 +48,7 @@ const char *const kAttrNameIndex = "index"; | |||||
const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | ||||
const char *const kOpTypeConstant = "Constant"; | const char *const kOpTypeConstant = "Constant"; | ||||
const char *const kOpTypeInput = "Input"; | const char *const kOpTypeInput = "Input"; | ||||
const char *const kFileConstant = "FileConstant"; | |||||
class OnnxUtil { | class OnnxUtil { | ||||
public: | public: | ||||
@@ -45,7 +45,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||||
domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | ||||
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) { | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const { | |||||
if (parent_node->attribute_size() != kIfNodeAttrSize) { | if (parent_node->attribute_size() != kIfNodeAttrSize) { | ||||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | ||||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | ||||
@@ -32,7 +32,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||||
private: | private: | ||||
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, | ||||
const std::string &parent_graph_name); | |||||
const std::string &parent_graph_name) const; | |||||
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; | ||||
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; | void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; | ||||
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; | void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; | ||||
@@ -59,7 +59,7 @@ Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector<const NodeD | |||||
} | } | ||||
Status TensorFlowFusionCustomParserAdapter::ParseParams(const std::vector<ge::Operator> &v_input_const, | Status TensorFlowFusionCustomParserAdapter::ParseParams(const std::vector<ge::Operator> &v_input_const, | ||||
ge::NodePtr &node) { | |||||
ge::NodePtr &node) const { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto op_dest = node->GetOpDesc(); | auto op_dest = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_dest); | GE_CHECK_NOTNULL(op_dest); | ||||
@@ -42,7 +42,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionCustomParserAdapter : public Tensor | |||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
* @author | * @author | ||||
*/ | */ | ||||
Status ParseParams(const std::vector<ge::Operator> &v_input_const, ge::NodePtr &node); | |||||
Status ParseParams(const std::vector<ge::Operator> &v_input_const, ge::NodePtr &node) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -195,11 +195,10 @@ void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr paren | |||||
auto parend_desc = parent_node->GetOpDesc(); | auto parend_desc = parent_node->GetOpDesc(); | ||||
(void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | (void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | ||||
if (original_names.empty()) { | if (original_names.empty()) { | ||||
original_names.emplace_back(string(subgraph_name).append("/").append(node->GetName())); | |||||
} else { | |||||
// for fusion node also used original_names[0] | |||||
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName()); | |||||
original_names.emplace_back(parent_node->GetName()); | |||||
} | } | ||||
// for fusion node also used original_names[0] | |||||
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName()); | |||||
if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) { | if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) { | ||||
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str()); | GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str()); | ||||
@@ -3050,7 +3049,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
GE_CHECK_NOTNULL(current_node); | GE_CHECK_NOTNULL(current_node); | ||||
for (const string &input_name : current_node->input()) { | for (const string &input_name : current_node->input()) { | ||||
string input_node_name = NodeNameFromInput(input_name); | string input_node_name = NodeNameFromInput(input_name); | ||||
if (!delete_nodes.count(input_node_name)) { | |||||
if (delete_nodes.count(input_node_name) == 0U) { | |||||
next_inputs.insert(input_node_name); | next_inputs.insert(input_node_name); | ||||
} | } | ||||
} | } | ||||
@@ -3063,7 +3062,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
if (static_cast<bool>(input_nodes.count(node.name()))) { | if (static_cast<bool>(input_nodes.count(node.name()))) { | ||||
*(filtered_graph_def.mutable_node()->Add()) = node; | *(filtered_graph_def.mutable_node()->Add()) = node; | ||||
} | } | ||||
if (!delete_nodes.count(node.name())) { | |||||
if (delete_nodes.count(node.name()) == 0U) { | |||||
*(filtered_graph_def.mutable_node()->Add()) = node; | *(filtered_graph_def.mutable_node()->Add()) = node; | ||||
} | } | ||||
} | } | ||||
@@ -3126,7 +3125,7 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef | |||||
GE_CHECK_NOTNULL(current_node); | GE_CHECK_NOTNULL(current_node); | ||||
for (const string &input_name : current_node->input()) { | for (const string &input_name : current_node->input()) { | ||||
string input_node_name = NodeNameFromInput(input_name); | string input_node_name = NodeNameFromInput(input_name); | ||||
if (!required_nodes.count(input_node_name)) { | |||||
if (required_nodes.count(input_node_name) == 0U) { | |||||
next_inputs.insert(input_node_name); | next_inputs.insert(input_node_name); | ||||
} | } | ||||
} | } | ||||
@@ -15,6 +15,7 @@ | |||||
*/ | */ | ||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
#include <string> | |||||
typedef int mmErrorMSg; | typedef int mmErrorMSg; | ||||
@@ -301,3 +302,22 @@ CHAR *mmGetErrorFormatMessage(mmErrorMSg errnum, CHAR *buf, mmSize size) | |||||
} | } | ||||
return strerror_r(errnum, buf, size); | return strerror_r(errnum, buf, size); | ||||
} | } | ||||
CHAR *mmDirName(CHAR *path) { | |||||
if (path == NULL) { | |||||
return NULL; | |||||
} | |||||
#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) | |||||
char separator = '\\'; | |||||
#else | |||||
char separator = '/'; | |||||
#endif | |||||
std::string path_str(path); | |||||
const size_t last_sep_pos = path_str.rfind(separator); | |||||
if (last_sep_pos == std::string::npos) { | |||||
return NULL; | |||||
} | |||||
path[last_sep_pos] = '\0'; | |||||
return path; | |||||
} |
@@ -277,6 +277,7 @@ set(PARSER_SRC_FILES | |||||
"${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc" | |||||
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
@@ -4245,7 +4245,7 @@ TEST_F(STestTensorflowParser, AddDumpOriginName_test) | |||||
std::vector<std::string> original_names; | std::vector<std::string> original_names; | ||||
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | (void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | ||||
EXPECT_EQ(original_names.empty(), false); | EXPECT_EQ(original_names.empty(), false); | ||||
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1"); | |||||
EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1"); | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -278,6 +278,7 @@ set(PARSER_SRC_FILES | |||||
"${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc" | |||||
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
"${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
@@ -30,6 +30,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "parser/onnx/onnx_constant_parser.h" | #include "parser/onnx/onnx_constant_parser.h" | ||||
#include "parser/onnx/onnx_file_constant_parser.h" | |||||
#include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
#include "parser/onnx/onnx_parser.h" | #include "parser/onnx/onnx_parser.h" | ||||
#undef protected | #undef protected | ||||
@@ -316,6 +317,190 @@ TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertDataType_test) | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
} | } | ||||
TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||||
{ | |||||
OnnxFileConstantParser parser; | |||||
ge::onnx::NodeProto input_node; | |||||
ge::onnx::TensorProto tensor_proto; | |||||
Status ret = parser.GetTensorProto(&input_node, tensor_proto); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
attribute->set_name("attribute"); | |||||
attribute = input_node.add_attribute(); | |||||
attribute->set_name("value"); | |||||
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
*attribute_tensor = tensor_proto; | |||||
ret = parser.GetTensorProto(&input_node, tensor_proto); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(UtestOnnxParser, FileConstantParseShape) | |||||
{ | |||||
OnnxFileConstantParser parser; | |||||
ge::onnx::TensorProto tensor_proto; | |||||
tensor_proto.add_dims(4); | |||||
tensor_proto.add_dims(2); | |||||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
parser.ParseShape(tensor_proto, op); | |||||
std::vector<int64_t> attr_value; | |||||
op.GetAttr("shape", attr_value); | |||||
EXPECT_EQ(attr_value.size(), 2U); | |||||
if (attr_value.size() == 2U) { | |||||
EXPECT_EQ(attr_value[0], 4); | |||||
EXPECT_EQ(attr_value[1], 2); | |||||
} | |||||
} | |||||
TEST_F(UtestOnnxParser, FileConstantParseDataType) | |||||
{ | |||||
OnnxFileConstantParser parser; | |||||
ge::onnx::TensorProto tensor_proto; | |||||
tensor_proto.set_data_type(OnnxDataType::UNDEFINED); | |||||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
Status ret = parser.ParseDataType(tensor_proto, op); | |||||
EXPECT_EQ(ret, FAILED); | |||||
tensor_proto.set_data_type(OnnxDataType::UINT8); | |||||
ret = parser.ParseDataType(tensor_proto, op); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
ge::DataType attr_value; | |||||
op.GetAttr("dtype", attr_value); | |||||
EXPECT_EQ(attr_value, ge::DataType::DT_UINT8); | |||||
} | |||||
TEST_F(UtestOnnxParser, FileConstantParseAttr) | |||||
{ | |||||
OnnxFileConstantParser parser; | |||||
ge::onnx::StringStringEntryProto string_proto; | |||||
ge::NamedAttrs attrs; | |||||
// test location | |||||
string_proto.set_key("location"); | |||||
string_proto.set_value("/usr/local"); | |||||
Status ret = parser.SetPathAttr(string_proto, attrs); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
std::string attr_value; | |||||
AttrUtils::GetStr(attrs, "location", attr_value); | |||||
EXPECT_EQ(attr_value, "/usr/local"); | |||||
// test offset | |||||
string_proto.set_key("offset"); | |||||
string_proto.set_value("123"); | |||||
ret = parser.SetPathAttr(string_proto, attrs); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
int64_t offset_value; | |||||
AttrUtils::GetInt(attrs, "offset", offset_value); | |||||
EXPECT_EQ(offset_value, 123 * 4096); | |||||
// offset overflow | |||||
string_proto.set_key("offset"); | |||||
string_proto.set_value("9223372036854775800"); | |||||
ret = parser.SetPathAttr(string_proto, attrs); | |||||
EXPECT_EQ(ret, FAILED); | |||||
// itol exception | |||||
string_proto.set_key("offset"); | |||||
string_proto.set_value("999999999999999999999999999999999999"); | |||||
ret = parser.SetPathAttr(string_proto, attrs); | |||||
EXPECT_EQ(ret, FAILED); | |||||
} | |||||
TEST_F(UtestOnnxParser, FileConstantParsePath) | |||||
{ | |||||
OnnxFileConstantParser parser; | |||||
ge::onnx::TensorProto tensor_proto; | |||||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
// without location, error | |||||
auto ret = parser.ParsePath(tensor_proto, op); | |||||
EXPECT_EQ(ret, FAILED); | |||||
// SetPathAttr error | |||||
ge::onnx::StringStringEntryProto *offset_proto = tensor_proto.add_external_data(); | |||||
offset_proto->set_key("offset"); | |||||
offset_proto->set_value("999999999999999999999999999999"); | |||||
ret = parser.ParsePath(tensor_proto, op); | |||||
EXPECT_EQ(ret, FAILED); | |||||
// has location, success | |||||
ge::onnx::StringStringEntryProto *string_proto = tensor_proto.add_external_data(); | |||||
string_proto->set_key("location"); | |||||
string_proto->set_value("/usr/local"); | |||||
offset_proto->set_key("offset"); | |||||
offset_proto->set_value("0"); | |||||
ret = parser.ParsePath(tensor_proto, op); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
// check location | |||||
std::string attr_value; | |||||
ge::NamedAttrs attrs; | |||||
AttrUtils::GetNamedAttrs(op_desc_src, "file_constant_path", attrs); | |||||
AttrUtils::GetStr(attrs, "location", attr_value); | |||||
EXPECT_EQ(attr_value, "/usr/local"); | |||||
} | |||||
TEST_F(UtestOnnxParser, FileConstantParseParam) | |||||
{ | |||||
OnnxFileConstantParser parser; | |||||
ge::onnx::NodeProto input_node; | |||||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
// get tensor proto failed | |||||
auto ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
EXPECT_EQ(ret, FAILED); | |||||
ge::onnx::TensorProto tensor_proto; | |||||
ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
attribute->set_name("value"); | |||||
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
*attribute_tensor = tensor_proto; | |||||
// parse data type failed | |||||
attribute_tensor->set_data_type(OnnxDataType::UNDEFINED); | |||||
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
EXPECT_EQ(ret, FAILED); | |||||
// parse path failed | |||||
attribute_tensor->set_data_type(OnnxDataType::UINT16); | |||||
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
EXPECT_EQ(ret, FAILED); | |||||
// success | |||||
ge::onnx::StringStringEntryProto *string_proto = attribute_tensor->add_external_data(); | |||||
string_proto->set_key("location"); | |||||
string_proto->set_value("/usr/local"); | |||||
attribute_tensor->add_dims(4); | |||||
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
// check location, shape, dtype | |||||
NamedAttrs attrs; | |||||
AttrUtils::GetNamedAttrs(*op_desc_src, "file_constant_path", attrs); | |||||
std::string file_path; | |||||
AttrUtils::GetStr(attrs, "location", file_path); | |||||
EXPECT_EQ(file_path, "/usr/local"); | |||||
std::vector<int64_t> dims; | |||||
op.GetAttr("shape", dims); | |||||
EXPECT_EQ(dims.size(), 1); | |||||
if (!dims.empty()) { | |||||
EXPECT_EQ(dims[0], 4); | |||||
} | |||||
DataType dtype; | |||||
op.GetAttr("dtype", dtype); | |||||
EXPECT_EQ(dtype, ge::DataType::DT_UINT16); | |||||
} | |||||
TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | ||||
{ | { | ||||
OnnxModelParser model_parser; | OnnxModelParser model_parser; | ||||
@@ -388,6 +573,25 @@ TEST_F(UtestOnnxParser, onnx_test_ModelParseToGraph) | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
} | } | ||||
TEST_F(UtestOnnxParser, onnx_test_SetExternalPath) | |||||
{ | |||||
OnnxModelParser modelParser; | |||||
ge::onnx::ModelProto model_proto; | |||||
auto ret = modelParser.SetExternalPath("", model_proto); | |||||
EXPECT_NE(ret, SUCCESS); | |||||
ge::onnx::GraphProto &graph_proto = const_cast<ge::onnx::GraphProto &>(model_proto.graph()); | |||||
graph_proto.add_initializer(); | |||||
ge::onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); | |||||
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); | |||||
tensor_proto->add_external_data(); | |||||
ge::onnx::StringStringEntryProto *string_proto = tensor_proto->add_external_data(); | |||||
string_proto->set_key("location"); | |||||
string_proto->set_value("if.onnx"); | |||||
ret = modelParser.SetExternalPath("/usr/local", model_proto); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory) | TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory) | ||||
{ | { | ||||
OnnxModelParser modelParser; | OnnxModelParser modelParser; | ||||
@@ -4712,7 +4712,7 @@ TEST_F(UtestTensorflowParser, AddDumpOriginName_test) | |||||
std::vector<std::string> original_names; | std::vector<std::string> original_names; | ||||
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | (void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | ||||
EXPECT_EQ(original_names.empty(), false); | EXPECT_EQ(original_names.empty(), false); | ||||
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1"); | |||||
EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1"); | |||||
} | } | ||||
} // namespace ge | } // namespace ge |