Browse Source

onnx parser st

pull/414/head
y00500818 3 years ago
parent
commit
e4093339ca
9 changed files with 182 additions and 21 deletions
  1. +0
    -20
      parser/onnx/onnx_util.cc
  2. +0
    -1
      parser/onnx/onnx_util.h
  3. +39
    -0
      tests/depends/ops_stub/ops_stub.h
  4. +64
    -0
      tests/st/parser_st_utils.cc
  5. +6
    -0
      tests/st/parser_st_utils.h
  6. BIN
      tests/st/testcase/origin_models/onnx_clip_v9.onnx
  7. +28
    -0
      tests/st/testcase/origin_models/onnx_clip_v9.py
  8. BIN
      tests/st/testcase/origin_models/onnx_const_type.onnx
  9. +45
    -0
      tests/st/testcase/test_onnx_parser.cc

+ 0
- 20
parser/onnx/onnx_util.cc View File

@@ -29,17 +29,6 @@ const std::map<uint32_t, ge::DataType> onnx_data_type_map = {
{OnnxDataType::COMPLEX64, ge::DataType::DT_COMPLEX64}, {OnnxDataType::COMPLEX128, ge::DataType::DT_COMPLEX128}, {OnnxDataType::COMPLEX64, ge::DataType::DT_COMPLEX64}, {OnnxDataType::COMPLEX128, ge::DataType::DT_COMPLEX128},
{OnnxDataType::BFLOAT16, ge::DataType::DT_UNDEFINED}, {OnnxDataType::BFLOAT16, ge::DataType::DT_UNDEFINED},
}; };

const std::map<uint32_t, int64_t> onnx_data_type_size_map = {
{OnnxDataType::FLOAT, sizeof(float)}, {OnnxDataType::UINT8, sizeof(uint8_t)},
{OnnxDataType::INT8, sizeof(int8_t)}, {OnnxDataType::UINT16, sizeof(uint16_t)},
{OnnxDataType::INT16, sizeof(int16_t)}, {OnnxDataType::INT32, sizeof(int32_t)},
{OnnxDataType::INT64, sizeof(int64_t)}, {OnnxDataType::STRING, sizeof(std::string)},
{OnnxDataType::BOOL, sizeof(bool)}, {OnnxDataType::FLOAT16, 2},
{OnnxDataType::DOUBLE, sizeof(double)}, {OnnxDataType::UINT32, sizeof(uint32_t)},
{OnnxDataType::UINT64, sizeof(uint64_t)}, {OnnxDataType::COMPLEX64, 8},
{OnnxDataType::COMPLEX128, 16}, {OnnxDataType::BFLOAT16, 2},
};
} }


namespace ge { namespace ge {
@@ -52,15 +41,6 @@ ge::DataType OnnxUtil::ConvertOnnxDataType(int64_t onnx_data_type) {
} }
} }


int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) {
auto search = onnx_data_type_size_map.find(onnx_data_type);
if (search != onnx_data_type_size_map.end()) {
return search->second;
} else {
return ge::DataType::DT_UNDEFINED;
}
}

void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name) { const std::string &parent_node_name, std::string &unique_subgraph_name) {
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name;


+ 0
- 1
parser/onnx/onnx_util.h View File

@@ -52,7 +52,6 @@ const char *const kOpTypeInput = "Input";
class OnnxUtil { class OnnxUtil {
public: public:
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type);
static int64_t CaculateDataSize(int64_t onnx_data_type);
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name); const std::string &parent_node_name, std::string &unique_subgraph_name);
}; };


+ 39
- 0
tests/depends/ops_stub/ops_stub.h View File

@@ -19,6 +19,7 @@


#include "external/graph/operator_reg.h" #include "external/graph/operator_reg.h"
#include "register/op_registry.h" #include "register/op_registry.h"
#include "graph/utils/op_desc_utils.h"


namespace ge { namespace ge {
// for ir // for ir
@@ -99,6 +100,14 @@ REG_OP(Abs)
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
.OP_END_FACTORY_REG(Abs) .OP_END_FACTORY_REG(Abs)


REG_OP(PartitionedCall)
.DYNAMIC_INPUT(args, TensorType::ALL())
.DYNAMIC_OUTPUT(output, TensorType::ALL())
.GRAPH(f)
.ATTR(config, String, "")
.ATTR(config_proto, String, "")
.ATTR(executor_type, String, "")
.OP_END_FACTORY_REG(PartitionedCall)


// for plugin // for plugin
static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
@@ -127,6 +136,29 @@ static Status ParseSubgraphPostFnIfStub(const std::string& subgraph_name, const
}); });
} }


static Status ParseParamsClipV9Stub(const Message* op_src, ge::Operator& op_dest) {
auto opDesc = ge::OpDescUtils::GetOpDescFromOperator(op_dest);
// 1.add dynamic input and out
opDesc->AddDynamicInputDesc("x", 1);
opDesc->AddDynamicOutputDesc("output", 1);

// 2.set original_type
ge::AttrUtils::SetStr(opDesc, "original_type", "ai.onnx::9::Clip");
return SUCCESS;
}

static Status ParseOpToGraphClipV9Stub(const Operator& op, Graph& graph) {
auto data0 = op::Data("data0").set_attr_index(0);
auto abs0 = op::Abs("abs0").set_input_x(data0);

std::vector<Operator> inputs{data0};
std::vector<std::pair<Operator, std::vector<size_t> > > output_indexs;
output_indexs.emplace_back(abs0, vector<std::size_t>{0});
graph.SetInputs(inputs).SetOutputs(output_indexs);
return SUCCESS;
}


// caffe plugin // caffe plugin
REGISTER_CUSTOM_OP("Data") REGISTER_CUSTOM_OP("Data")
.FrameworkType(domi::CAFFE) .FrameworkType(domi::CAFFE)
@@ -170,5 +202,12 @@ REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::TENSORFLOW) .FrameworkType(domi::TENSORFLOW)
.OriginOpType("Add") .OriginOpType("Add")
.ParseParamsFn(ParseParamsStub); .ParseParamsFn(ParseParamsStub);


REGISTER_CUSTOM_OP("PartitionedCall")
.FrameworkType(domi::ONNX)
.OriginOpType({"ai.onnx::9::Clip"})
.ParseParamsFn(ParseParamsClipV9Stub)
.ParseOpToGraphFn(ParseOpToGraphClipV9Stub);
} // namespace ge } // namespace ge
#endif // MAIN_OPS_STUB_H #endif // MAIN_OPS_STUB_H

+ 64
- 0
tests/st/parser_st_utils.cc View File

@@ -16,6 +16,7 @@


#include "st/parser_st_utils.h" #include "st/parser_st_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include <limits.h>


namespace ge { namespace ge {
void ParerSTestsUtils::ClearParserInnerCtx() { void ParerSTestsUtils::ClearParserInnerCtx() {
@@ -41,4 +42,67 @@ void ParerSTestsUtils::ClearParserInnerCtx() {
ge::GetParserContext().enable_scope_fusion_passes = ""; ge::GetParserContext().enable_scope_fusion_passes = "";
GELOGI("Clear parser inner context successfully."); GELOGI("Clear parser inner context successfully.");
} }

MemBuffer* ParerSTestsUtils::MemBufferFromFile(const char *path) {
char path_temp[PATH_MAX + 1] = {0x00};
if(strlen(path) > PATH_MAX || nullptr == realpath(path, path_temp)) {
return nullptr;
}
FILE *fp = fopen(path_temp, "r+");
if (fp == nullptr) {
return nullptr;
}

// get model file length
if (0 != fseek(fp, 0, SEEK_END)) {
fclose(fp);
return nullptr;
}
long file_length = ftell(fp);
if (fseek(fp, 0, SEEK_SET)) {
fclose(fp);
return nullptr;
}
if (file_length <= 0) {
fclose(fp);
return nullptr;
}

// alloc model buffer
void *data = malloc((unsigned int)file_length);
if (!data) {
fclose(fp);
return nullptr;
}

// read file into memory
uint32_t read_size = (uint32_t)fread(data, 1, (unsigned int)file_length, fp);

// check if read success
if ((long)read_size != file_length) {
free(data);
data = nullptr;
fclose(fp);
return nullptr;
}

// close model file
fclose(fp);

// create an MemBuffer
MemBuffer* membuf = new MemBuffer();
if (!membuf) {
free(data);
data = nullptr;
return nullptr;
}
membuf->data = malloc((unsigned int)read_size);

// set size && data
membuf->size = (uint32_t)read_size;
memcpy((char*)membuf->data, (char*)data, read_size);
free(data);
return membuf;
}

} // namespace ge } // namespace ge

+ 6
- 0
tests/st/parser_st_utils.h View File

@@ -20,9 +20,15 @@
#include "framework/omg/parser/parser_inner_ctx.h" #include "framework/omg/parser/parser_inner_ctx.h"


namespace ge { namespace ge {
struct MemBuffer {
void *data;
uint32_t size;
};

class ParerSTestsUtils { class ParerSTestsUtils {
public: public:
static void ClearParserInnerCtx(); static void ClearParserInnerCtx();
static MemBuffer* MemBufferFromFile(const char *path);
}; };
} // namespace ge } // namespace ge




BIN
tests/st/testcase/origin_models/onnx_clip_v9.onnx View File


+ 28
- 0
tests/st/testcase/origin_models/onnx_clip_v9.py View File

@@ -0,0 +1,28 @@
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto


def make_clip_V9():
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4, 5])
node_def = helper.make_node('Clip',
inputs=['X'],
outputs=['Y'],
max = 1.0,
min = -1.0,
)
graph = helper.make_graph(
[node_def],
"test_clip_case_V9",
[X],
[Y],
)

model = helper.make_model(graph, producer_name="onnx-mul_test")
model.opset_import[0].version = 9
onnx.save(model, "./onnx_clip_v9.onnx")


if __name__ == '__main__':
make_clip_V9()

BIN
tests/st/testcase/origin_models/onnx_const_type.onnx View File


+ 45
- 0
tests/st/testcase/test_onnx_parser.cc View File

@@ -24,6 +24,7 @@
#include "st/parser_st_utils.h" #include "st/parser_st_utils.h"
#include "external/ge/ge_api_types.h" #include "external/ge/ge_api_types.h"
#include "tests/depends/ops_stub/ops_stub.h" #include "tests/depends/ops_stub/ops_stub.h"
#include "parser/onnx/onnx_parser.h"


namespace ge { namespace ge {
class STestOnnxParser : public testing::Test { class STestOnnxParser : public testing::Test {
@@ -128,4 +129,48 @@ TEST_F(STestOnnxParser, onnx_parser_if_node) {
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS); EXPECT_EQ(ret, GRAPH_SUCCESS);
} }

TEST_F(STestOnnxParser, onnx_parser_expand_one_to_many) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS);

MemBuffer *buffer = ParerSTestsUtils::MemBufferFromFile(model_file.c_str());
ret = ge::aclgrphParseONNXFromMem(reinterpret_cast<char *>(buffer->data), buffer->size, parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

TEST_F(STestOnnxParser, onnx_parser_to_json) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
OnnxModelParser onnx_parser;

const char *json_file = "tmp.json";
auto ret = onnx_parser.ToJson(model_file.c_str(), json_file);
EXPECT_EQ(ret, SUCCESS);

const char *json_null = nullptr;
ret = onnx_parser.ToJson(model_file.c_str(), json_null);
EXPECT_EQ(ret, FAILED);
const char *model_null = nullptr;
ret = onnx_parser.ToJson(model_null, json_null);
EXPECT_EQ(ret, FAILED);
}

TEST_F(STestOnnxParser, onnx_parser_const_data_type) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/onnx_const_type.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

} // namespace ge } // namespace ge

Loading…
Cancel
Save