You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

onnx_constant_parser.cc 9.4 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "onnx_constant_parser.h"
  17. #include <map>
  18. #include <vector>
  19. #include "parser/common/acl_graph_parser_util.h"
  20. #include "framework/omg/parser/parser_inner_ctx.h"
  21. #include "graph/ge_tensor.h"
  22. #include "graph/utils/tensor_adapter.h"
  23. #include "parser/common/op_parser_factory.h"
  24. #include "parser/onnx/onnx_util.h"
  25. using ge::onnx::NodeProto;
  26. using ge::onnx::TensorProto;
  27. using domi::ONNX;
  28. using GeShape = ge::GeShape;
  29. using GeTensorDesc = ge::GeTensorDesc;
  30. using namespace ge::parser;
  31. namespace ge {
  32. Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) {
  33. int64_t data_type = tensor_proto.data_type();
  34. if (ge::OnnxUtil::ConvertOnnxDataType(data_type) == ge::DataType::DT_UNDEFINED) {
  35. REPORT_INNER_ERROR("E19999", "data_type %ld not support.", data_type);
  36. GELOGE(FAILED, "[Check][Param] data_type %ld not support.", data_type);
  37. return FAILED;
  38. }
  39. if (count == 0) {
  40. GELOGI("At least one dim equals zero, result in the count equal to zero.");
  41. return SUCCESS;
  42. }
  43. std::map<uint32_t, int32_t> datatype_val_size_map = {
  44. // for int32, uint8, int8, uint16, int16, bool, and float16 values
  45. {OnnxDataType::INT32, tensor_proto.int32_data_size()},
  46. {OnnxDataType::UINT8, tensor_proto.int32_data_size()},
  47. {OnnxDataType::INT8, tensor_proto.int32_data_size()},
  48. {OnnxDataType::UINT16, tensor_proto.int32_data_size()},
  49. {OnnxDataType::INT16, tensor_proto.int32_data_size()},
  50. {OnnxDataType::BOOL, tensor_proto.int32_data_size()},
  51. {OnnxDataType::FLOAT16, tensor_proto.int32_data_size()},
  52. // for int64 values
  53. {OnnxDataType::INT64, tensor_proto.int64_data_size()},
  54. // for string values
  55. {OnnxDataType::STRING, tensor_proto.string_data_size()},
  56. // for float and complex64 values
  57. {OnnxDataType::FLOAT, tensor_proto.float_data_size()},
  58. {OnnxDataType::COMPLEX64, tensor_proto.float_data_size()},
  59. // for double and complex128 values
  60. {OnnxDataType::DOUBLE, tensor_proto.double_data_size()},
  61. {OnnxDataType::COMPLEX128, tensor_proto.double_data_size()},
  62. // for uint64 and uint32 values
  63. {OnnxDataType::UINT64, tensor_proto.uint64_data_size()},
  64. {OnnxDataType::UINT32, tensor_proto.uint64_data_size()},
  65. };
  66. int32_t datatype_val_size = 0;
  67. std::map<uint32_t, int32_t>::const_iterator iter = datatype_val_size_map.find(data_type);
  68. if (iter != datatype_val_size_map.end()) {
  69. datatype_val_size = iter->second;
  70. } else {
  71. REPORT_INNER_ERROR("E19999", "data_type %ld not support.", data_type);
  72. GELOGE(domi::PARAM_INVALID, "[Find][DataType]data_type %ld not support.", data_type);
  73. return FAILED;
  74. }
  75. // find raw data
  76. if (datatype_val_size == 0) {
  77. if (tensor_proto.raw_data().empty()) {
  78. REPORT_INNER_ERROR("E19999", "tensor_proto has no elements or raw_data");
  79. GELOGE(domi::PARAM_INVALID, "[Check][Param]tensor_proto has no elements or raw_data");
  80. return FAILED;
  81. }
  82. if (data_type == OnnxDataType::STRING) {
  83. tensor.SetData(tensor_proto.raw_data().c_str());
  84. } else {
  85. tensor.SetData(PtrToPtr<const char_t, const uint8_t>(tensor_proto.raw_data().c_str()),
  86. tensor_proto.raw_data().size());
  87. }
  88. GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size());
  89. return SUCCESS;
  90. }
  91. // find _data() elements
  92. ParseConvertDataElements(tensor_proto, tensor, count, data_type);
  93. return SUCCESS;
  94. }
  95. void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor,
  96. int count, int64_t data_type) {
  97. switch (data_type) {
  98. // for int32, uint8, int8, uint16, int16, bool, and float16 values
  99. case OnnxDataType::INT32:
  100. case OnnxDataType::UINT8:
  101. case OnnxDataType::INT8:
  102. case OnnxDataType::UINT16:
  103. case OnnxDataType::INT16:
  104. case OnnxDataType::BOOL:
  105. case OnnxDataType::FLOAT16:
  106. (void)SetTensorData(tensor_proto.int32_data_size(), tensor_proto.int32_data(), count, tensor);
  107. break;
  108. // for int64 values
  109. case OnnxDataType::INT64:
  110. (void)SetTensorData(tensor_proto.int64_data_size(), tensor_proto.int64_data(), count, tensor);
  111. break;
  112. // for string values
  113. case OnnxDataType::STRING: {
  114. std::vector<AscendString> data;
  115. for (auto str_data : tensor_proto.string_data()) {
  116. data.emplace_back(AscendString(str_data.c_str()));
  117. }
  118. tensor.SetData(data);
  119. break;
  120. }
  121. // for float and complex64 values
  122. case OnnxDataType::FLOAT:
  123. (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(), count, tensor);
  124. break;
  125. case OnnxDataType::COMPLEX64:
  126. (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(),
  127. tensor_proto.float_data_size(), tensor);
  128. break;
  129. // for double and complex128 values
  130. case OnnxDataType::DOUBLE:
  131. (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(), count, tensor);
  132. break;
  133. case OnnxDataType::COMPLEX128:
  134. (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(),
  135. tensor_proto.double_data_size(), tensor);
  136. break;
  137. // for uint64 and uint32 values
  138. case OnnxDataType::UINT64:
  139. case OnnxDataType::UINT32:
  140. (void)SetTensorData(tensor_proto.uint64_data_size(), tensor_proto.uint64_data(), count, tensor);
  141. break;
  142. default:
  143. break;
  144. }
  145. }
  146. Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) {
  147. // convert shape and format
  148. std::vector<int64_t> tmp_shape;
  149. int count = 1;
  150. for (int i = 0; i < tensor_proto.dims_size(); i++) {
  151. tmp_shape.push_back(tensor_proto.dims(i));
  152. int64_t dim = tmp_shape[i];
  153. // support weights shape [0],have no weights
  154. if (dim < 0 || (count != 0 && (dim >= INT64_MAX / count))) {
  155. REPORT_INNER_ERROR("E19999", "Dim size is invalid, dim is less than zero or dim size exceeds INT64_MAX.");
  156. GELOGE(FAILED, "[Check][Param] Dim size is invalid, dim is less than zero or dim size exceeds INT64_MAX.");
  157. return FAILED;
  158. }
  159. count *= dim;
  160. };
  161. TensorDesc tensor_desc = tensor.GetTensorDesc();
  162. tensor_desc.SetShape(ge::Shape(tmp_shape));
  163. tensor.SetTensorDesc(tensor_desc);
  164. // set data
  165. if (ParseConvertData(tensor_proto, tensor, count) != SUCCESS) {
  166. GELOGE(FAILED, "[Invoke][ParseConvertData]Convert ge tensor data and format failed.");
  167. return FAILED;
  168. }
  169. return SUCCESS;
  170. }
  171. Status OnnxConstantParser::ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) {
  172. int64_t data_type = tensor_proto.data_type();
  173. ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type);
  174. if (type == ge::DataType::DT_UNDEFINED) {
  175. REPORT_INNER_ERROR("E19999", "tensor_proto date type %ld is undefined.", data_type);
  176. GELOGE(domi::PARAM_INVALID, "[Check][Param] tensor_proto date type %ld is undefined.", data_type);
  177. return FAILED;
  178. }
  179. TensorDesc tensor_desc = tensor.GetTensorDesc();
  180. tensor_desc.SetDataType(ge::DataType(type));
  181. tensor.SetTensorDesc(tensor_desc);
  182. return SUCCESS;
  183. }
  184. Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def) {
  185. GE_CHECK_NOTNULL(op_src);
  186. const NodeProto *node = PtrToPtr<const ge::onnx::NodeProto, const NodeProto>(op_src);
  187. // Get const Tensor from node
  188. Tensor tensor;
  189. for (auto it : node->attribute()) {
  190. if (it.name() != ge::kAttrNameValue) {
  191. continue;
  192. }
  193. const ::ge::onnx::TensorProto it_tensor = it.t();
  194. if (ParseConvertDataType(it_tensor, tensor) != SUCCESS) {
  195. GELOGE(FAILED, "[Check][Param] Convert ge tensor date type failed, attribute name is %s.", it.name().c_str());
  196. return FAILED;
  197. }
  198. if (ParseConvertTensor(it_tensor, tensor) != SUCCESS) {
  199. GELOGE(FAILED, "[Check][Param] Convert ge tensor shape and format failed, attribute name is %s.",
  200. it.name().c_str());
  201. return FAILED;
  202. }
  203. }
  204. op_def.SetAttr(ge::kAttrNameValue, tensor);
  205. return SUCCESS;
  206. }
  207. Status OnnxConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) {
  208. GE_CHECK_NOTNULL(op_src);
  209. const ge::onnx::NodeProto *node = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src);
  210. GE_CHECK_NOTNULL(node);
  211. GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str());
  212. if (ParseConstFromInput(node, op_def) != SUCCESS) {
  213. GELOGE(FAILED, "[Parse][Constant] node %s failed", node->name().c_str());
  214. return FAILED;
  215. }
  216. return SUCCESS;
  217. }
  // Registers OnnxConstantParser as the parser creator for framework ONNX, op type CONSTANT.
  // The macro is presumably defined in parser/common/op_parser_factory.h; confirm the registration
  // mechanism there.
  REGISTER_OP_PARSER_CREATOR(ONNX, CONSTANT, OnnxConstantParser);
  219. } // namespace ge