@@ -29,17 +29,6 @@ const std::map<uint32_t, ge::DataType> onnx_data_type_map = { | |||||
{OnnxDataType::COMPLEX64, ge::DataType::DT_COMPLEX64}, {OnnxDataType::COMPLEX128, ge::DataType::DT_COMPLEX128}, | {OnnxDataType::COMPLEX64, ge::DataType::DT_COMPLEX64}, {OnnxDataType::COMPLEX128, ge::DataType::DT_COMPLEX128}, | ||||
{OnnxDataType::BFLOAT16, ge::DataType::DT_UNDEFINED}, | {OnnxDataType::BFLOAT16, ge::DataType::DT_UNDEFINED}, | ||||
}; | }; | ||||
const std::map<uint32_t, int64_t> onnx_data_type_size_map = { | |||||
{OnnxDataType::FLOAT, sizeof(float)}, {OnnxDataType::UINT8, sizeof(uint8_t)}, | |||||
{OnnxDataType::INT8, sizeof(int8_t)}, {OnnxDataType::UINT16, sizeof(uint16_t)}, | |||||
{OnnxDataType::INT16, sizeof(int16_t)}, {OnnxDataType::INT32, sizeof(int32_t)}, | |||||
{OnnxDataType::INT64, sizeof(int64_t)}, {OnnxDataType::STRING, sizeof(std::string)}, | |||||
{OnnxDataType::BOOL, sizeof(bool)}, {OnnxDataType::FLOAT16, 2}, | |||||
{OnnxDataType::DOUBLE, sizeof(double)}, {OnnxDataType::UINT32, sizeof(uint32_t)}, | |||||
{OnnxDataType::UINT64, sizeof(uint64_t)}, {OnnxDataType::COMPLEX64, 8}, | |||||
{OnnxDataType::COMPLEX128, 16}, {OnnxDataType::BFLOAT16, 2}, | |||||
}; | |||||
} | } | ||||
namespace ge { | namespace ge { | ||||
@@ -52,15 +41,6 @@ ge::DataType OnnxUtil::ConvertOnnxDataType(int64_t onnx_data_type) { | |||||
} | } | ||||
} | } | ||||
int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) { | |||||
auto search = onnx_data_type_size_map.find(onnx_data_type); | |||||
if (search != onnx_data_type_size_map.end()) { | |||||
return search->second; | |||||
} else { | |||||
return ge::DataType::DT_UNDEFINED; | |||||
} | |||||
} | |||||
void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | ||||
const std::string &parent_node_name, std::string &unique_subgraph_name) { | const std::string &parent_node_name, std::string &unique_subgraph_name) { | ||||
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | ||||
@@ -52,7 +52,6 @@ const char *const kOpTypeInput = "Input"; | |||||
class OnnxUtil { | class OnnxUtil { | ||||
public: | public: | ||||
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | ||||
static int64_t CaculateDataSize(int64_t onnx_data_type); | |||||
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | ||||
const std::string &parent_node_name, std::string &unique_subgraph_name); | const std::string &parent_node_name, std::string &unique_subgraph_name); | ||||
}; | }; | ||||
@@ -19,6 +19,7 @@ | |||||
#include "external/graph/operator_reg.h" | #include "external/graph/operator_reg.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace ge { | namespace ge { | ||||
// for ir | // for ir | ||||
@@ -99,6 +100,14 @@ REG_OP(Abs) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | ||||
.OP_END_FACTORY_REG(Abs) | .OP_END_FACTORY_REG(Abs) | ||||
REG_OP(PartitionedCall) | |||||
.DYNAMIC_INPUT(args, TensorType::ALL()) | |||||
.DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||||
.GRAPH(f) | |||||
.ATTR(config, String, "") | |||||
.ATTR(config_proto, String, "") | |||||
.ATTR(executor_type, String, "") | |||||
.OP_END_FACTORY_REG(PartitionedCall) | |||||
// for plugin | // for plugin | ||||
static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | ||||
@@ -127,6 +136,29 @@ static Status ParseSubgraphPostFnIfStub(const std::string& subgraph_name, const | |||||
}); | }); | ||||
} | } | ||||
static Status ParseParamsClipV9Stub(const Message* op_src, ge::Operator& op_dest) { | |||||
auto opDesc = ge::OpDescUtils::GetOpDescFromOperator(op_dest); | |||||
// 1.add dynamic input and out | |||||
opDesc->AddDynamicInputDesc("x", 1); | |||||
opDesc->AddDynamicOutputDesc("output", 1); | |||||
// 2.set original_type | |||||
ge::AttrUtils::SetStr(opDesc, "original_type", "ai.onnx::9::Clip"); | |||||
return SUCCESS; | |||||
} | |||||
static Status ParseOpToGraphClipV9Stub(const Operator& op, Graph& graph) { | |||||
auto data0 = op::Data("data0").set_attr_index(0); | |||||
auto abs0 = op::Abs("abs0").set_input_x(data0); | |||||
std::vector<Operator> inputs{data0}; | |||||
std::vector<std::pair<Operator, std::vector<size_t> > > output_indexs; | |||||
output_indexs.emplace_back(abs0, vector<std::size_t>{0}); | |||||
graph.SetInputs(inputs).SetOutputs(output_indexs); | |||||
return SUCCESS; | |||||
} | |||||
// caffe plugin | // caffe plugin | ||||
REGISTER_CUSTOM_OP("Data") | REGISTER_CUSTOM_OP("Data") | ||||
.FrameworkType(domi::CAFFE) | .FrameworkType(domi::CAFFE) | ||||
@@ -170,5 +202,12 @@ REGISTER_CUSTOM_OP("Add") | |||||
.FrameworkType(domi::TENSORFLOW) | .FrameworkType(domi::TENSORFLOW) | ||||
.OriginOpType("Add") | .OriginOpType("Add") | ||||
.ParseParamsFn(ParseParamsStub); | .ParseParamsFn(ParseParamsStub); | ||||
REGISTER_CUSTOM_OP("PartitionedCall") | |||||
.FrameworkType(domi::ONNX) | |||||
.OriginOpType({"ai.onnx::9::Clip"}) | |||||
.ParseParamsFn(ParseParamsClipV9Stub) | |||||
.ParseOpToGraphFn(ParseOpToGraphClipV9Stub); | |||||
} // namespace ge | } // namespace ge | ||||
#endif // MAIN_OPS_STUB_H | #endif // MAIN_OPS_STUB_H |
@@ -16,6 +16,7 @@ | |||||
#include "st/parser_st_utils.h" | #include "st/parser_st_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include <limits.h> | |||||
namespace ge { | namespace ge { | ||||
void ParerSTestsUtils::ClearParserInnerCtx() { | void ParerSTestsUtils::ClearParserInnerCtx() { | ||||
@@ -41,4 +42,67 @@ void ParerSTestsUtils::ClearParserInnerCtx() { | |||||
ge::GetParserContext().enable_scope_fusion_passes = ""; | ge::GetParserContext().enable_scope_fusion_passes = ""; | ||||
GELOGI("Clear parser inner context successfully."); | GELOGI("Clear parser inner context successfully."); | ||||
} | } | ||||
MemBuffer* ParerSTestsUtils::MemBufferFromFile(const char *path) { | |||||
char path_temp[PATH_MAX + 1] = {0x00}; | |||||
if(strlen(path) > PATH_MAX || nullptr == realpath(path, path_temp)) { | |||||
return nullptr; | |||||
} | |||||
FILE *fp = fopen(path_temp, "r+"); | |||||
if (fp == nullptr) { | |||||
return nullptr; | |||||
} | |||||
// get model file length | |||||
if (0 != fseek(fp, 0, SEEK_END)) { | |||||
fclose(fp); | |||||
return nullptr; | |||||
} | |||||
long file_length = ftell(fp); | |||||
if (fseek(fp, 0, SEEK_SET)) { | |||||
fclose(fp); | |||||
return nullptr; | |||||
} | |||||
if (file_length <= 0) { | |||||
fclose(fp); | |||||
return nullptr; | |||||
} | |||||
// alloc model buffer | |||||
void *data = malloc((unsigned int)file_length); | |||||
if (!data) { | |||||
fclose(fp); | |||||
return nullptr; | |||||
} | |||||
// read file into memory | |||||
uint32_t read_size = (uint32_t)fread(data, 1, (unsigned int)file_length, fp); | |||||
// check if read success | |||||
if ((long)read_size != file_length) { | |||||
free(data); | |||||
data = nullptr; | |||||
fclose(fp); | |||||
return nullptr; | |||||
} | |||||
// close model file | |||||
fclose(fp); | |||||
// create an MemBuffer | |||||
MemBuffer* membuf = new MemBuffer(); | |||||
if (!membuf) { | |||||
free(data); | |||||
data = nullptr; | |||||
return nullptr; | |||||
} | |||||
membuf->data = malloc((unsigned int)read_size); | |||||
// set size && data | |||||
membuf->size = (uint32_t)read_size; | |||||
memcpy((char*)membuf->data, (char*)data, read_size); | |||||
free(data); | |||||
return membuf; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -20,9 +20,15 @@ | |||||
#include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
namespace ge { | namespace ge { | ||||
struct MemBuffer { | |||||
void *data; | |||||
uint32_t size; | |||||
}; | |||||
class ParerSTestsUtils { | class ParerSTestsUtils { | ||||
public: | public: | ||||
static void ClearParserInnerCtx(); | static void ClearParserInnerCtx(); | ||||
static MemBuffer* MemBufferFromFile(const char *path); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -0,0 +1,28 @@ | |||||
import onnx | |||||
from onnx import helper | |||||
from onnx import AttributeProto, TensorProto, GraphProto | |||||
def make_clip_V9(): | |||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4, 5]) | |||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4, 5]) | |||||
node_def = helper.make_node('Clip', | |||||
inputs=['X'], | |||||
outputs=['Y'], | |||||
max = 1.0, | |||||
min = -1.0, | |||||
) | |||||
graph = helper.make_graph( | |||||
[node_def], | |||||
"test_clip_case_V9", | |||||
[X], | |||||
[Y], | |||||
) | |||||
model = helper.make_model(graph, producer_name="onnx-mul_test") | |||||
model.opset_import[0].version = 9 | |||||
onnx.save(model, "./onnx_clip_v9.onnx") | |||||
if __name__ == '__main__': | |||||
make_clip_V9() |
@@ -24,6 +24,7 @@ | |||||
#include "st/parser_st_utils.h" | #include "st/parser_st_utils.h" | ||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
#include "tests/depends/ops_stub/ops_stub.h" | #include "tests/depends/ops_stub/ops_stub.h" | ||||
#include "parser/onnx/onnx_parser.h" | |||||
namespace ge { | namespace ge { | ||||
class STestOnnxParser : public testing::Test { | class STestOnnxParser : public testing::Test { | ||||
@@ -128,4 +129,48 @@ TEST_F(STestOnnxParser, onnx_parser_if_node) { | |||||
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | ||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
} | } | ||||
TEST_F(STestOnnxParser, onnx_parser_expand_one_to_many) { | |||||
std::string case_dir = __FILE__; | |||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx"; | |||||
std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
ge::Graph graph; | |||||
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
MemBuffer *buffer = ParerSTestsUtils::MemBufferFromFile(model_file.c_str()); | |||||
ret = ge::aclgrphParseONNXFromMem(reinterpret_cast<char *>(buffer->data), buffer->size, parser_params, graph); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
} | |||||
TEST_F(STestOnnxParser, onnx_parser_to_json) { | |||||
std::string case_dir = __FILE__; | |||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx"; | |||||
std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
OnnxModelParser onnx_parser; | |||||
const char *json_file = "tmp.json"; | |||||
auto ret = onnx_parser.ToJson(model_file.c_str(), json_file); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
const char *json_null = nullptr; | |||||
ret = onnx_parser.ToJson(model_file.c_str(), json_null); | |||||
EXPECT_EQ(ret, FAILED); | |||||
const char *model_null = nullptr; | |||||
ret = onnx_parser.ToJson(model_null, json_null); | |||||
EXPECT_EQ(ret, FAILED); | |||||
} | |||||
TEST_F(STestOnnxParser, onnx_parser_const_data_type) { | |||||
std::string case_dir = __FILE__; | |||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
std::string model_file = case_dir + "/origin_models/onnx_const_type.onnx"; | |||||
std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
ge::Graph graph; | |||||
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |