Compare commits

...

6 Commits

Author SHA1 Message Date
  薛鹏 3c87768cc6 !590 去除根图节点上的dump origin name属性 2 years ago
  蒋荣强 33c9ce1ebd !588 fix opensource problem 3 years ago
  黄桂军 a1b1ffab38 !587 超过2G的onnx模型导入 3 years ago
  薛鹏 774e74a3f9 !581 修复dump子图节点时 origin name未拼接根图节点的问题 3 years ago
  王涛 8a3973ce99 !574 update owners 3 years ago
  王涛 c813469db2 !571 update .gitmodules. 3 years ago
33 changed files with 549 additions and 77 deletions
Split View
  1. +1
    -1
      .gitmodules
  2. +5
    -6
      OWNERS
  3. +1
    -1
      metadef
  4. +8
    -7
      parser/caffe/caffe_parser.cc
  5. +8
    -9
      parser/caffe/caffe_parser.h
  6. +6
    -6
      parser/common/acl_graph_parser_util.cc
  7. +4
    -3
      parser/common/acl_graph_parser_util.h
  8. +1
    -1
      parser/common/model_saver.cc
  9. +20
    -20
      parser/common/parser_fp16_t.cc
  10. +1
    -0
      parser/common/parser_types.cc
  11. +0
    -2
      parser/common/pass.h
  12. +3
    -3
      parser/common/pre_checker.cc
  13. +1
    -1
      parser/common/proto_file_parser.cc
  14. +1
    -1
      parser/common/register_tbe.cc
  15. +1
    -1
      parser/common/register_tbe.h
  16. +1
    -0
      parser/onnx/CMakeLists.txt
  17. +1
    -0
      parser/onnx/module.mk
  18. +150
    -0
      parser/onnx/onnx_file_constant_parser.cc
  19. +37
    -0
      parser/onnx/onnx_file_constant_parser.h
  20. +58
    -2
      parser/onnx/onnx_parser.cc
  21. +2
    -0
      parser/onnx/onnx_parser.h
  22. +1
    -0
      parser/onnx/onnx_util.h
  23. +1
    -1
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  24. +1
    -1
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  25. +1
    -1
      parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc
  26. +1
    -1
      parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h
  27. +6
    -7
      parser/tensorflow/tensorflow_parser.cc
  28. +20
    -0
      tests/depends/mmpa/src/mmpa_stub.cc
  29. +1
    -0
      tests/st/CMakeLists.txt
  30. +1
    -1
      tests/st/testcase/test_tensorflow_parser.cc
  31. +1
    -0
      tests/ut/parser/CMakeLists.txt
  32. +204
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc
  33. +1
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
.gitmodules View File

@@ -1,4 +1,4 @@
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git
branch = master
branch = r1.9.0

+ 5
- 6
OWNERS View File

@@ -1,13 +1,12 @@
approvers:
- ji_chen
- wqtshg
- ljl0711
- liu-jisheng
- startzgf168
- andylhy
- liyihan123
- zhangfan_hq
- lipeiyang3699
reviewers:
- xchu42
- sheng-nan
- tangqunzhang
- wangxiaotian22
- stevenaw
- stevenaw
- xuepenginnanjing

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 4f61fa7a7181e0e7dbdd4acbfaf99088a58d920d
Subproject commit 35de9facd31448995922246c5d2ffaa5a726bbb1

+ 8
- 7
parser/caffe/caffe_parser.cc View File

@@ -74,6 +74,7 @@ using std::ifstream;

namespace {
const size_t kMaxErrStrLen = 128U;
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map;
} // namespace

namespace ge {
@@ -282,7 +283,7 @@ Status CheckPathValid(const char *model_path, const string &custom_proto, string
const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxWithLoss", "Accuracy", "Data",
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"};

Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) {
Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const {
if (proto_message.input_size() > 0) {
GELOGI("This net exsit input.");

@@ -456,7 +457,7 @@ Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &
return ret;
}

Status CaffeModelParser::ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) {
Status CaffeModelParser::ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) const {
int32_t copy_fd = mmDup(STDERR_FILENO);
if (copy_fd < 0) {
char_t err_buf[kMaxErrStrLen + 1U] = {};
@@ -536,7 +537,7 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google::

Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
vector<ge::Operator> &operators) {
vector<ge::Operator> &operators) const {
auto field_name = layer_descriptor->FindFieldByName(kFieldName);
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor");
auto field_type = layer_descriptor->FindFieldByName(kFieldType);
@@ -624,7 +625,7 @@ void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_ind
ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index));
}

Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) {
Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) const {
if (ge::GetParserContext().user_out_tensors.empty()) {
return SUCCESS;
}
@@ -932,7 +933,7 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom
}

Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
const string &op_type) {
const string &op_type) const {
if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) {
op_desc = ge::parser::MakeShared<ge::OpDesc>(layer.name(), op_type);
GE_CHECK_NOTNULL(op_desc);
@@ -1202,7 +1203,7 @@ std::string CaffeModelParser::RemapTopNameByLayer(const domi::caffe::LayerParame
return (top_name + "_" + layer.name() + "_" + std::to_string(index));
}

Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) {
Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) const {
// Add layer in the model to PreChecker and check the general parameters
PreChecker::Instance().SetModelName(net.name());
for (int i = 0; i < net.layer_size(); i++) {
@@ -1977,7 +1978,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r
}

Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message,
google::protobuf::Message *blobs) {
google::protobuf::Message *blobs) const {
const google::protobuf::Reflection *blobs_reflection = message->GetReflection();
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message");
vector<const google::protobuf::FieldDescriptor *> field_desc;


+ 8
- 9
parser/caffe/caffe_parser.h View File

@@ -52,12 +52,11 @@ using std::string;
using std::unordered_map;
using std::vector;
using domi::Status;
static std::map<std::vector<std::string>, std::vector<std::string>> params_share_map;

class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
public:
CaffeModelParser() {}
virtual ~CaffeModelParser() override {}
~CaffeModelParser() override {}

/**
* @ingroup domi_omg
@@ -145,7 +144,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return SUCCESS build successfully
* @return FAILED build failed
*/
Status PreCheck(const domi::caffe::NetParameter &net);
Status PreCheck(const domi::caffe::NetParameter &net) const;

/**
* @ingroup domi_omg
@@ -156,7 +155,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return SUCCESS build successfully
* @return FAILED build failed
*/
Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag);
Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const;

/*
* @ingroup domi_omg
@@ -192,7 +191,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return SUCCESS read file successfully
* @return FAILED read file failed
*/
Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message);
Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) const;

/*
* @ingroup domi_omg
@@ -214,7 +213,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
* @return FAILED parse layer failed
*/
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message, std::vector<ge::Operator> &operators);
const google::protobuf::Message *message, std::vector<ge::Operator> &operators) const;

/*
* @ingroup domi_omg
@@ -301,7 +300,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) const;

Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
const string &op_type);
const string &op_type) const;

Status AddUserOutNodesTop();

@@ -321,7 +320,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {

void AddOutputInfoToContext(string layer_name, int32_t top_index) const;

Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message);
Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) const;

Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer);

@@ -405,7 +404,7 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {
google::protobuf::Message *layer);

Status ConvertBlobsProto(const google::protobuf::Message *message,
google::protobuf::Message *blobs);
google::protobuf::Message *blobs) const;

Status ConvertBlobShapeProto(const google::protobuf::Message *message,
google::protobuf::Message *dest_message) const;


+ 6
- 6
parser/common/acl_graph_parser_util.cc View File

@@ -266,7 +266,7 @@ void AclGrphParseUtil::SetDefaultFormat() {
}
}

domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) const {
try {
ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_nodes.clear();
@@ -492,7 +492,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
}

domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const {
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes;
if (!default_out_nodes.empty()) {
for (size_t i = 0; i < default_out_nodes.size(); ++i) {
@@ -587,7 +587,7 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS
}

string key_str = key_ascend;
auto it = ge::ir_option::ir_parser_suppported_options.find(key_str);
std::set<std::string>::const_iterator it = ge::ir_option::ir_parser_suppported_options.find(key_str);
if (it == ge::ir_option::ir_parser_suppported_options.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"parser_params", key_str});
GELOGE(PARAM_INVALID, "[Check][Param] Input options include unsupported option(%s).Please check!", key_ascend);
@@ -651,7 +651,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin
}

domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
const std::map<AscendString, AscendString> &parser_params) {
const std::map<AscendString, AscendString> &parser_params) const {
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout,
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
@@ -943,7 +943,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std
regex_t reg;
int cflags = REG_EXTENDED | REG_NOSUB;
int ret = regcomp(&reg, mode.c_str(), cflags);
if (ret) {
if (ret != 0) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGW("regcomp failed, reason: %s", ebuff);
regfree(&reg);
@@ -951,7 +951,7 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &filePath, const std
}

ret = regexec(&reg, filePath.c_str(), 0, nullptr, 0);
if (ret) {
if (ret != 0) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGE(ge::PARAM_INVALID, "[Invoke][RegExec] failed, reason: %s", ebuff);
regfree(&reg);


+ 4
- 3
parser/common/acl_graph_parser_util.h View File

@@ -44,7 +44,8 @@ class AclGrphParseUtil {
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
std::string &graph_name);
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString,
AscendString> &parser_params) const;

private:
bool parser_initialized = false;
@@ -53,7 +54,7 @@ class AclGrphParseUtil {
void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) const;
static void SetDefaultFormat();
domi::Status ParseAclOutputNodes(const std::string &out_nodes);
domi::Status ParseAclOutputNodes(const std::string &out_nodes) const;
domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16) const;
domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes) const;
static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name,
@@ -61,7 +62,7 @@ class AclGrphParseUtil {
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
const string &is_input_adjust_hw_layout) const;
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const;
};

namespace parser {


+ 1
- 1
parser/common/model_saver.cc View File

@@ -77,7 +77,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi
const char *model_char = model_str.c_str();
uint32_t len = static_cast<uint32_t>(model_str.length());
// Write data to file
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len);
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>(static_cast<const void *>(model_char)), len);
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) {
char_t err_buf[kMaxErrStrLen + 1U] = {};
const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrStrLen);


+ 20
- 20
parser/common/parser_fp16_t.cc View File

@@ -48,7 +48,7 @@ static bool IsRoundOne(uint64_t man, uint16_t trunc_len) {
uint64_t mask0 = 0x4;
uint64_t mask1 = 0x2;
uint64_t mask2;
uint16_t shift_out = static_cast<uint16_t>(trunc_len - kDim2);
uint16_t shift_out = static_cast<uint16_t>(trunc_len - static_cast<uint16_t>(kDim2));
mask0 = mask0 << shift_out;
mask1 = mask1 << shift_out;
mask2 = mask1 - 1;
@@ -89,7 +89,7 @@ static float Fp16ToFloat(const uint16_t &fp_val) {
int16_t hf_exp;
ExtractFp16(fp_val, hf_sign, hf_exp, hf_man);

while (hf_man && !(hf_man & kFp16ManHideBit)) {
while ((hf_man != 0U) && ((hf_man & kFp16ManHideBit) == 0U)) {
hf_man <<= 1;
hf_exp--;
}
@@ -120,7 +120,7 @@ static double Fp16ToDouble(const uint16_t &fp_val) {
int16_t hf_exp;
ExtractFp16(fp_val, hf_sign, hf_exp, hf_man);

while (hf_man && !(hf_man & kFp16ManHideBit)) {
while ((hf_man != 0U) && ((hf_man & kFp16ManHideBit) == 0U)) {
hf_man <<= 1;
hf_exp--;
}
@@ -128,7 +128,7 @@ static double Fp16ToDouble(const uint16_t &fp_val) {
uint64_t e_ret;
uint64_t m_ret;
uint64_t s_ret = hf_sign;
if (!hf_man) {
if (hf_man == 0U) {
e_ret = 0;
m_ret = 0;
} else {
@@ -256,7 +256,7 @@ static uint8_t Fp16ToUInt8(const uint16_t &fp_val) {
shift_out++;
}
}
if (!overflow_flag) {
if (overflow_flag == 0U) {
bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen);
m_ret = static_cast<uint8_t>((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max);
if (need_round && m_ret != kBitLen8Max) {
@@ -290,7 +290,7 @@ static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, co
if (m_ret == 0) {
s_ret = 0;
}
return static_cast<uint16_t>((s_ret << kBitShift15) | (m_ret));
return static_cast<uint16_t>((s_ret << static_cast<uint16_t>(kBitShift15)) | (m_ret));
}

/// @ingroup fp16_t math conversion static method
@@ -431,7 +431,7 @@ static int32_t Fp16ToInt32(const uint16_t &fp_val) {
s_ret = 0;
}
// Generate final result
ret_v = (s_ret << kBitShift31) | (m_ret);
ret_v = (s_ret << static_cast<uint16_t>(kBitShift31)) | (m_ret);
}

return *(ge::PtrToPtr<uint32_t, uint32_t>(&ret_v));
@@ -565,7 +565,7 @@ static uint16_t Fp16Add(uint16_t v_1, uint16_t v_2) {
m_trunc = (m_b << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp)));
m_b = RightShift(m_b, e_tmp);
} else if (e_a < e_b) {
m_trunc = (m_a << (kBitShift32 - static_cast<uint16_t>(e_tmp)));
m_trunc = (m_a << (static_cast<uint16_t>(kBitShift32) - static_cast<uint16_t>(e_tmp)));
m_a = RightShift(m_a, e_tmp);
}
// calculate mantissav
@@ -603,7 +603,7 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) {
m_a = m_a_tmp;
m_b = m_b_tmp;

e_ret = ((e_a + e_b) - kFp16ExpBias) - kDim10;
e_ret = ((e_a + e_b) - kFp16ExpBias) - static_cast<int16_t>(kDim10);
mul_m = m_a * m_b;
s_ret = s_a ^ s_b;

@@ -905,7 +905,7 @@ fp16_t &fp16_t::operator=(const float &f_val) {
fp16_t &fp16_t::operator=(const int8_t &i_val) {
uint16_t s_ret, e_ret, m_ret;

s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> kDim7);
s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> static_cast<uint8_t>(kDim7));
m_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & kInt8Max));

if (m_ret == 0) {
@@ -952,14 +952,14 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u
uint16_t len = static_cast<uint16_t>(GetManBitLength(m_tmp));
if (static_cast<bool>(m_tmp)) {
int16_t e_ret;
if (len > kDim11) {
if (len > static_cast<uint16_t>(kDim11)) {
e_ret = kFp16ExpBias + kFp16ManLen;
uint16_t e_tmp = len - static_cast<uint16_t>(kDim11);
uint32_t trunc_mask = 1;
for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1;
}
uint32_t m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp);
uint32_t m_trunc = (m_tmp & trunc_mask) << (static_cast<uint16_t>(kBitShift32) - e_tmp);
for (int i = 0; i < e_tmp; i++) {
m_tmp = (m_tmp >> 1);
e_ret = e_ret + 1;
@@ -991,7 +991,7 @@ fp16_t &fp16_t::operator=(const int16_t &i_val) {
val = 0;
} else {
uint16_t ui_val = *(ge::PtrToPtr<const int16_t, const int16_t>(&i_val));
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift15);
auto s_ret = static_cast<uint16_t>(ui_val >> static_cast<uint16_t>(kBitShift15));
if (static_cast<bool>(s_ret)) {
int16_t iValM = -i_val;
ui_val = *(ge::PtrToPtr<int16_t, uint16_t>(&iValM));
@@ -1018,7 +1018,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) {
for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1;
}
m_trunc = (m_ret & trunc_mask) << (kBitShift32 - e_tmp);
m_trunc = (m_ret & trunc_mask) << (static_cast<uint16_t>(kBitShift32) - e_tmp);
for (int i = 0; i < e_tmp; i++) {
m_ret = (m_ret >> 1);
e_ret = e_ret + 1;
@@ -1040,7 +1040,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) {
}
} else {
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_ret = m_ret << (kDim11 - len);
m_ret = m_ret << (static_cast<uint16_t>(kDim11) - len);
e_ret = e_ret + (len - 1);
}
val = FP16_CONSTRUCTOR(0u, static_cast<uint16_t>(e_ret), m_ret);
@@ -1062,7 +1062,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u
for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1;
}
m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp);
m_trunc = (m_tmp & trunc_mask) << (static_cast<uint16_t>(kBitShift32) - e_tmp);
for (int i = 0; i < e_tmp; i++) {
m_tmp = (m_tmp >> 1);
e_ret = e_ret + 1;
@@ -1085,7 +1085,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u
}
} else {
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_tmp = m_tmp << (kDim11 - len);
m_tmp = m_tmp << (static_cast<uint16_t>(kDim11) - len);
e_ret = e_ret + (len - 1);
}
auto m_ret = static_cast<uint16_t>(m_tmp);
@@ -1097,7 +1097,7 @@ fp16_t &fp16_t::operator=(const int32_t &i_val) {
val = 0;
} else {
uint32_t ui_val = *(ge::PtrToPtr<const int32_t, const uint32_t>(&i_val));
auto s_ret = static_cast<uint16_t>(ui_val >> kBitShift31);
auto s_ret = static_cast<uint16_t>(ui_val >> static_cast<uint16_t>(kBitShift31));
if (static_cast<bool>(s_ret)) {
int32_t iValM = -i_val;
ui_val = *(ge::PtrToPtr<int32_t, uint32_t>(&iValM));
@@ -1124,7 +1124,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) {
for (int i = 1; i < e_tmp; i++) {
trunc_mask = (trunc_mask << 1) + 1;
}
m_trunc = (m_tmp & trunc_mask) << static_cast<uint32_t>(kBitShift32 - e_tmp);
m_trunc = (m_tmp & trunc_mask) << static_cast<uint32_t>(static_cast<uint16_t>(kBitShift32) - e_tmp);
for (uint16_t i = 0; i < e_tmp; i++) {
m_tmp = (m_tmp >> 1);
e_ret = e_ret + 1;
@@ -1147,7 +1147,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) {
}
} else {
e_ret = static_cast<int16_t>(kFp16ExpBias);
m_tmp = m_tmp << (kDim11 - len);
m_tmp = m_tmp << (static_cast<uint16_t>(kDim11) - len);
e_ret = e_ret + (len - 1);
}
auto m_ret = static_cast<uint16_t>(m_tmp);


+ 1
- 0
parser/common/parser_types.cc View File

@@ -131,6 +131,7 @@ const char *YOLO2REORG = "Yolo2Reorg";
const char *REDUCESUM = "ReduceSum";
const char *SUM = "Sum";
const char *CONSTANT = "Const";
const char *FILECONSTANT = "FileConstant";
const char *RESIZEBILINEAR = "ResizeBilinear";
const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad";
const char *MAXIMUM = "Maximum";


+ 0
- 2
parser/common/pass.h View File

@@ -19,8 +19,6 @@

#include <memory>

#include "common/fmk_error_codes.h"

namespace ge {
///
/// @ingroup domi_omg


+ 3
- 3
parser/common/pre_checker.cc View File

@@ -218,9 +218,9 @@ Status PreChecker::Save(const string &file) {

// Constructing JSON information of operators in order of network
for (auto id : ops_) {
auto iter = op_map_.find(id);
GE_CHK_BOOL_RET_STATUS(iter != op_map_.end(), FAILED, "[Check][Param] don't find this op.");
Info &info = iter->second;
std::map<OpId, Info>::const_iterator iter = op_map_.find(id);
GE_CHK_BOOL_RET_STATUS(iter != op_map_.cend(), FAILED, "[Check][Param] don't find this op.");
const Info &info = iter->second;

// Initialization operator general information
nlohmann::json op = {{kKeyOpName, info.name}, {kKeyOpType, info.type}};


+ 1
- 1
parser/common/proto_file_parser.cc View File

@@ -67,7 +67,7 @@ bool GetIdentifier(const std::string &line, int &identifier) {
break;
}
if (line[i] >= kMinNum && line[i] <= kMaxNum) {
identifier = identifier * kDecimalMulti + line[i] - kMinNum;
identifier = identifier * kDecimalMulti + static_cast<int>(line[i]) - static_cast<int>(kMinNum);
}
if (identifier > kMaxIdentifier || identifier < 0) {
return false;


+ 1
- 1
parser/common/register_tbe.cc View File

@@ -75,7 +75,7 @@ bool OpRegistrationTbe::Finalize(const OpRegistrationData &reg_data, bool is_tra
return ret;
}

bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {
bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) const {
if (reg_data.GetFrameworkType() == domi::TENSORFLOW) {
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
if (factory == nullptr) {


+ 1
- 1
parser/common/register_tbe.h View File

@@ -27,7 +27,7 @@ class OpRegistrationTbe {
bool Finalize(const OpRegistrationData &reg_data, bool is_train = false);

private:
bool RegisterParser(const OpRegistrationData &reg_data);
bool RegisterParser(const OpRegistrationData &reg_data) const;
};
} // namespace ge


+ 1
- 0
parser/onnx/CMakeLists.txt View File

@@ -4,6 +4,7 @@ set(SRC_LIST
"onnx_data_parser.cc"
"onnx_util.cc"
"onnx_constant_parser.cc"
"onnx_file_constant_parser.cc"
"subgraph_adapter/if_subgraph_adapter.cc"
"subgraph_adapter/subgraph_adapter_factory.cc"
)


+ 1
- 0
parser/onnx/module.mk View File

@@ -17,6 +17,7 @@ PARSER_ONNX_SRC_FILES := \
onnx_data_parser.cc \
onnx_util.cc \
onnx_constant_parser.cc \
onnx_file_constant_parser.cc \
proto/onnx/ge_onnx.proto \
proto/om.proto \



+ 150
- 0
parser/onnx/onnx_file_constant_parser.cc View File

@@ -0,0 +1,150 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "onnx_file_constant_parser.h"
#include <vector>

#include "graph/ge_tensor.h"
#include "parser/common/op_parser_factory.h"
#include "parser/onnx/onnx_util.h"
#include "framework/common/util.h"
#include "framework/common/types.h"

using ge::onnx::NodeProto;
using ge::onnx::TensorProto;
using domi::ONNX;
using GeShape = ge::GeShape;
using GeTensorDesc = ge::GeTensorDesc;
using namespace ge::parser;

namespace {
const std::string kAttrShape = "shape";
const std::string kAttrDataType = "dtype";
const std::string kFileConstantPath = "file_constant_path";
const std::string kLocation = "location";
const std::string kOffset = "offset";
const int64_t kOffsetCoefficient = 4096;
const char *const kFileConstant = "FileConstant";
}
namespace ge {
Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) {
GE_CHECK_NOTNULL(op_src);
const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src);
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str());

ge::onnx::TensorProto tensor_proto;
if (GetTensorProto(node, tensor_proto) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str());
return FAILED;
}
if (ParseDataType(tensor_proto, op_def) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] parse data type failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse data type failed", node->name().c_str());
return FAILED;
}
if (ParsePath(tensor_proto, op_def) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "node[%s] parse file path failed", node->name().c_str());
GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse file path failed", node->name().c_str());
return FAILED;
}
ParseShape(tensor_proto, op_def);
return SUCCESS;
}

Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto *node_proto,
ge::onnx::TensorProto &tensor_proto) {
for (const auto &it : node_proto->attribute()) {
if (it.name() != ge::kAttrNameValue) {
continue;
}
tensor_proto = it.t();
return SUCCESS;
}
REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto->name().c_str());
GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto->name().c_str());
return FAILED;
}

void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
std::vector<int64_t> tmp_shape;
for (int i = 0; i < tensor_proto.dims_size(); i++) {
tmp_shape.push_back(tensor_proto.dims(i));
}
op_def.SetAttr(kAttrShape.c_str(), tmp_shape);
}

Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
int64_t data_type = tensor_proto.data_type();
ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type);
if (type >= ge::DataType::DT_UNDEFINED) {
REPORT_INNER_ERROR("E19999", "tensor_proto date type %ld is undefined.", data_type);
GELOGE(domi::PARAM_INVALID, "[Check][Param] tensor_proto date type %ld is undefined.", data_type);
return FAILED;
}

op_def.SetAttr(kAttrDataType.c_str(), type);
return SUCCESS;
}

Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) {
ge::NamedAttrs attrs;
for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) {
const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i);
if (SetPathAttr(string_proto, attrs) != SUCCESS) {
REPORT_INNER_ERROR("E19999", "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str());
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str());
return FAILED;
}
}

if (!attrs.HasAttr(kLocation)) {
REPORT_INNER_ERROR("E19999", "external tensor proto[%s] must contain location.", tensor_proto.name().c_str());
GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str());
return FAILED;
}
op_def.SetAttr(kFileConstantPath.c_str(), attrs);
return SUCCESS;
}

Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto,
ge::NamedAttrs &attrs) {
if (string_proto.key() == kLocation) {
AttrUtils::SetStr(attrs, kLocation, string_proto.value());
} else {
int64_t value;
try {
value = stol(string_proto.value());
} catch (const std::exception &e) {
REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what());
GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what());
return FAILED;
}
if (string_proto.key() == kOffset) {
if (std::numeric_limits<int64_t>::max() / kOffsetCoefficient < value) {
REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value);
return FAILED;
}
value *= kOffsetCoefficient;
}
AttrUtils::SetInt(attrs, string_proto.key(), value);
}
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser);
} // namespace ge

+ 37
- 0
parser/onnx/onnx_file_constant_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_
#define GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_

#include "parser/onnx/onnx_op_parser.h"
#include "proto/onnx/ge_onnx.pb.h"

namespace ge {
class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser {
public:
Status ParseParams(const Message *op_src, ge::Operator &op_def) override;

private:
Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def);
Status GetTensorProto(const ge::onnx::NodeProto *node_proto, ge::onnx::TensorProto &tensor_proto);
Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs);
};
} // namespace ge

#endif // GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_

+ 58
- 2
parser/onnx/onnx_parser.cc View File

@@ -44,6 +44,12 @@
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "subgraph_adapter/subgraph_adapter_factory.h"
#include "framework/common/types.h"
#include "mmpa/mmpa_api.h"

namespace {
const std::string kLocation = "location";
}

namespace ge {
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
@@ -160,7 +166,8 @@ namespace ge {
namespace {
const std::map<std::string, std::string> kOnnxOpMap = {
{ge::kOpTypeInput, ge::parser::DATA},
{ge::kOpTypeConstant, ge::parser::CONSTANT}
{ge::kOpTypeConstant, ge::parser::CONSTANT},
{ge::kFileConstant, ge::parser::FILECONSTANT}
};
const int64_t kDimValue = 1;

@@ -350,12 +357,16 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
ge::onnx::NodeProto *const_node = onnx_graph.add_node();
std::string output_name = it.first + "_" + to_string(index++);
const_node->set_name(output_name);
const_node->set_op_type(ge::kOpTypeConstant);
const_node->add_output(it.first);
ge::onnx::AttributeProto *attribute = const_node->add_attribute();
attribute->set_name(ge::kAttrNameValue);
ge::onnx::TensorProto *attribute_t = attribute->mutable_t();
*attribute_t = it.second;
if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) {
const_node->set_op_type(kFileConstant);
} else {
const_node->set_op_type(ge::kOpTypeConstant);
}
}

return SUCCESS;
@@ -723,6 +734,51 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto
GELOGE(PARAM_INVALID, "[Read][ModeFile] failed.");
return FAILED;
}

if (SetExternalPath(file, onnx_model) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Set external path failed, file[%s]", file);
GELOGE(PARAM_INVALID, "[Set][ExternalPath] failed.");
return PARAM_INVALID;
}
return SUCCESS;
}

Status OnnxModelParser::SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const {
std::string real_path = ge::parser::RealPath(file);
const size_t file_len = real_path.length();
std::unique_ptr<char[]> tmp_file(new (std::nothrow) char[file_len + 1U]);
GE_CHECK_NOTNULL(tmp_file);

const auto ret = strncpy_s(tmp_file.get(), file_len + 1U, real_path.c_str(), file_len);
if (ret != EN_OK) {
REPORT_CALL_ERROR("E19999", "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu, ret=%d.",
real_path.c_str(), tmp_file.get(), file_len, file_len + 1U, ret);
GELOGE(FAILED, "strncpy_s failed, src=%p, dst=%p, src_len=%zu, dst_len=%zu.",
real_path.c_str(), tmp_file.get(), file_len, file_len + 1U);
return FAILED;
}
const char *const dir = mmDirName(tmp_file.get());
GE_CHECK_NOTNULL(dir);

const ge::onnx::GraphProto &onnx_graph = onnx_model.graph();
for (int32_t i = 0; i < onnx_graph.initializer_size(); ++i) {
const ge::onnx::TensorProto &initializer_tensor = onnx_graph.initializer(i);
if (initializer_tensor.data_location() != ge::onnx::TensorProto_DataLocation_EXTERNAL) {
continue;
}
for (int32_t j = 0; j < initializer_tensor.external_data_size(); ++j) {
ge::onnx::StringStringEntryProto &string_proto =
const_cast<ge::onnx::StringStringEntryProto &>(initializer_tensor.external_data(j));
if (string_proto.key() != kLocation) {
continue;
}
const std::string &file_name = string_proto.value();
const std::string new_file = std::string(dir) + MMPA_PATH_SEPARATOR_STR + file_name;
GELOGD("[%s] is external data. concat dir[%s] and file_name[%s], new_file[%s]",
initializer_tensor.name().c_str(), dir, file_name.c_str(), new_file.c_str());
string_proto.set_value(new_file);
}
}
return SUCCESS;
}



+ 2
- 0
parser/onnx/onnx_parser.h View File

@@ -126,6 +126,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) const;

Status SetExternalPath(const char *file, ge::onnx::ModelProto &onnx_model) const;

Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) const;

Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph);


+ 1
- 0
parser/onnx/onnx_util.h View File

@@ -48,6 +48,7 @@ const char *const kAttrNameIndex = "index";
const char *const kAttrNameIsSubgraphOp = "is_subgraph_op";
const char *const kOpTypeConstant = "Constant";
const char *const kOpTypeInput = "Input";
const char *const kFileConstant = "FileConstant";

class OnnxUtil {
public:


+ 1
- 1
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -45,7 +45,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(

domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) const {
if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());


+ 1
- 1
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -32,7 +32,7 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
private:
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name);
const std::string &parent_graph_name) const;
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const;
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const;
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const;


+ 1
- 1
parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc View File

@@ -59,7 +59,7 @@ Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector<const NodeD
}

Status TensorFlowFusionCustomParserAdapter::ParseParams(const std::vector<ge::Operator> &v_input_const,
ge::NodePtr &node) {
ge::NodePtr &node) const {
GE_CHECK_NOTNULL(node);
auto op_dest = node->GetOpDesc();
GE_CHECK_NOTNULL(op_dest);


+ 1
- 1
parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h View File

@@ -42,7 +42,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionCustomParserAdapter : public Tensor
* @return FAILED parse failed
* @author
*/
Status ParseParams(const std::vector<ge::Operator> &v_input_const, ge::NodePtr &node);
Status ParseParams(const std::vector<ge::Operator> &v_input_const, ge::NodePtr &node) const;
};
} // namespace ge



+ 6
- 7
parser/tensorflow/tensorflow_parser.cc View File

@@ -195,11 +195,10 @@ void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr paren
auto parend_desc = parent_node->GetOpDesc();
(void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
if (original_names.empty()) {
original_names.emplace_back(string(subgraph_name).append("/").append(node->GetName()));
} else {
// for fusion node also used original_names[0]
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName());
original_names.emplace_back(parent_node->GetName());
}
// for fusion node also used original_names[0]
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName());

if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str());
@@ -3050,7 +3049,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef
GE_CHECK_NOTNULL(current_node);
for (const string &input_name : current_node->input()) {
string input_node_name = NodeNameFromInput(input_name);
if (!delete_nodes.count(input_node_name)) {
if (delete_nodes.count(input_node_name) == 0U) {
next_inputs.insert(input_node_name);
}
}
@@ -3063,7 +3062,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef
if (static_cast<bool>(input_nodes.count(node.name()))) {
*(filtered_graph_def.mutable_node()->Add()) = node;
}
if (!delete_nodes.count(node.name())) {
if (delete_nodes.count(node.name()) == 0U) {
*(filtered_graph_def.mutable_node()->Add()) = node;
}
}
@@ -3126,7 +3125,7 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef
GE_CHECK_NOTNULL(current_node);
for (const string &input_name : current_node->input()) {
string input_node_name = NodeNameFromInput(input_name);
if (!required_nodes.count(input_node_name)) {
if (required_nodes.count(input_node_name) == 0U) {
next_inputs.insert(input_node_name);
}
}


+ 20
- 0
tests/depends/mmpa/src/mmpa_stub.cc View File

@@ -15,6 +15,7 @@
*/

#include "mmpa/mmpa_api.h"
#include <string>

typedef int mmErrorMSg;

@@ -301,3 +302,22 @@ CHAR *mmGetErrorFormatMessage(mmErrorMSg errnum, CHAR *buf, mmSize size)
}
return strerror_r(errnum, buf, size);
}

CHAR *mmDirName(CHAR *path) {
if (path == NULL) {
return NULL;
}
#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER))
char separator = '\\';
#else
char separator = '/';
#endif
std::string path_str(path);
const size_t last_sep_pos = path_str.rfind(separator);
if (last_sep_pos == std::string::npos) {
return NULL;
}

path[last_sep_pos] = '\0';
return path;
}

+ 1
- 0
tests/st/CMakeLists.txt View File

@@ -277,6 +277,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/thread_pool.cc"
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc"
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_parser.cc"


+ 1
- 1
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -4245,7 +4245,7 @@ TEST_F(STestTensorflowParser, AddDumpOriginName_test)
std::vector<std::string> original_names;
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.empty(), false);
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1");
EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1");
}

} // namespace ge

+ 1
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -278,6 +278,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/common/thread_pool.cc"
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc"
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_parser.cc"


+ 204
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -30,6 +30,7 @@
#define protected public
#define private public
#include "parser/onnx/onnx_constant_parser.h"
#include "parser/onnx/onnx_file_constant_parser.h"
#include "parser/onnx/onnx_util.h"
#include "parser/onnx/onnx_parser.h"
#undef protected
@@ -316,6 +317,190 @@ TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertDataType_test)
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestOnnxParser, FileConstantGetTensorProto)
{
OnnxFileConstantParser parser;
ge::onnx::NodeProto input_node;
ge::onnx::TensorProto tensor_proto;
Status ret = parser.GetTensorProto(&input_node, tensor_proto);
EXPECT_EQ(ret, FAILED);

ge::onnx::AttributeProto *attribute = input_node.add_attribute();
attribute->set_name("attribute");
attribute = input_node.add_attribute();
attribute->set_name("value");

ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_proto;
ret = parser.GetTensorProto(&input_node, tensor_proto);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestOnnxParser, FileConstantParseShape)
{
OnnxFileConstantParser parser;
ge::onnx::TensorProto tensor_proto;
tensor_proto.add_dims(4);
tensor_proto.add_dims(2);
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

parser.ParseShape(tensor_proto, op);

std::vector<int64_t> attr_value;
op.GetAttr("shape", attr_value);
EXPECT_EQ(attr_value.size(), 2U);
if (attr_value.size() == 2U) {
EXPECT_EQ(attr_value[0], 4);
EXPECT_EQ(attr_value[1], 2);
}
}

TEST_F(UtestOnnxParser, FileConstantParseDataType)
{
OnnxFileConstantParser parser;
ge::onnx::TensorProto tensor_proto;
tensor_proto.set_data_type(OnnxDataType::UNDEFINED);
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

Status ret = parser.ParseDataType(tensor_proto, op);
EXPECT_EQ(ret, FAILED);

tensor_proto.set_data_type(OnnxDataType::UINT8);
ret = parser.ParseDataType(tensor_proto, op);
EXPECT_EQ(ret, SUCCESS);
ge::DataType attr_value;
op.GetAttr("dtype", attr_value);
EXPECT_EQ(attr_value, ge::DataType::DT_UINT8);
}

TEST_F(UtestOnnxParser, FileConstantParseAttr)
{
OnnxFileConstantParser parser;
ge::onnx::StringStringEntryProto string_proto;
ge::NamedAttrs attrs;

// test location
string_proto.set_key("location");
string_proto.set_value("/usr/local");
Status ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, SUCCESS);
std::string attr_value;
AttrUtils::GetStr(attrs, "location", attr_value);
EXPECT_EQ(attr_value, "/usr/local");

// test offset
string_proto.set_key("offset");
string_proto.set_value("123");
ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, SUCCESS);
int64_t offset_value;
AttrUtils::GetInt(attrs, "offset", offset_value);
EXPECT_EQ(offset_value, 123 * 4096);

// offset overflow
string_proto.set_key("offset");
string_proto.set_value("9223372036854775800");
ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, FAILED);

// itol exception
string_proto.set_key("offset");
string_proto.set_value("999999999999999999999999999999999999");
ret = parser.SetPathAttr(string_proto, attrs);
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestOnnxParser, FileConstantParsePath)
{
OnnxFileConstantParser parser;
ge::onnx::TensorProto tensor_proto;
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);


// without location, error
auto ret = parser.ParsePath(tensor_proto, op);
EXPECT_EQ(ret, FAILED);

// SetPathAttr error
ge::onnx::StringStringEntryProto *offset_proto = tensor_proto.add_external_data();
offset_proto->set_key("offset");
offset_proto->set_value("999999999999999999999999999999");
ret = parser.ParsePath(tensor_proto, op);
EXPECT_EQ(ret, FAILED);

// has location, success
ge::onnx::StringStringEntryProto *string_proto = tensor_proto.add_external_data();
string_proto->set_key("location");
string_proto->set_value("/usr/local");
offset_proto->set_key("offset");
offset_proto->set_value("0");
ret = parser.ParsePath(tensor_proto, op);
EXPECT_EQ(ret, SUCCESS);

// check location
std::string attr_value;
ge::NamedAttrs attrs;
AttrUtils::GetNamedAttrs(op_desc_src, "file_constant_path", attrs);
AttrUtils::GetStr(attrs, "location", attr_value);
EXPECT_EQ(attr_value, "/usr/local");
}

TEST_F(UtestOnnxParser, FileConstantParseParam)
{
OnnxFileConstantParser parser;
ge::onnx::NodeProto input_node;
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

// get tensor proto failed
auto ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, FAILED);

ge::onnx::TensorProto tensor_proto;
ge::onnx::AttributeProto *attribute = input_node.add_attribute();
attribute->set_name("value");
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
*attribute_tensor = tensor_proto;

// parse data type failed
attribute_tensor->set_data_type(OnnxDataType::UNDEFINED);
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, FAILED);

// parse path failed
attribute_tensor->set_data_type(OnnxDataType::UINT16);
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, FAILED);

// success
ge::onnx::StringStringEntryProto *string_proto = attribute_tensor->add_external_data();
string_proto->set_key("location");
string_proto->set_value("/usr/local");
attribute_tensor->add_dims(4);
ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op);
EXPECT_EQ(ret, SUCCESS);

// check location, shape, dtype
NamedAttrs attrs;
AttrUtils::GetNamedAttrs(*op_desc_src, "file_constant_path", attrs);
std::string file_path;
AttrUtils::GetStr(attrs, "location", file_path);
EXPECT_EQ(file_path, "/usr/local");

std::vector<int64_t> dims;
op.GetAttr("shape", dims);
EXPECT_EQ(dims.size(), 1);
if (!dims.empty()) {
EXPECT_EQ(dims[0], 4);
}
DataType dtype;
op.GetAttr("dtype", dtype);
EXPECT_EQ(dtype, ge::DataType::DT_UINT16);
}

TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test)
{
OnnxModelParser model_parser;
@@ -388,6 +573,25 @@ TEST_F(UtestOnnxParser, onnx_test_ModelParseToGraph)
EXPECT_EQ(ret, FAILED);
}

TEST_F(UtestOnnxParser, onnx_test_SetExternalPath)
{
OnnxModelParser modelParser;
ge::onnx::ModelProto model_proto;
auto ret = modelParser.SetExternalPath("", model_proto);
EXPECT_NE(ret, SUCCESS);

ge::onnx::GraphProto &graph_proto = const_cast<ge::onnx::GraphProto &>(model_proto.graph());
graph_proto.add_initializer();
ge::onnx::TensorProto* tensor_proto = graph_proto.add_initializer();
tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
tensor_proto->add_external_data();
ge::onnx::StringStringEntryProto *string_proto = tensor_proto->add_external_data();
string_proto->set_key("location");
string_proto->set_value("if.onnx");
ret = modelParser.SetExternalPath("/usr/local", model_proto);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestOnnxParser, onnx_test_ParseFromMemory)
{
OnnxModelParser modelParser;


+ 1
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -4712,7 +4712,7 @@ TEST_F(UtestTensorflowParser, AddDumpOriginName_test)
std::vector<std::string> original_names;
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.empty(), false);
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1");
EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1");
}

} // namespace ge

Loading…
Cancel
Save