You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.h 14 kB

modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: 
tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: 
tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: 
tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: 
tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: 
tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: 
tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file util.h
  18. * \brief
  19. */
  20. #ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
  21. #define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
  22. #include <memory.h>
  23. #include <string>
  24. #include <vector>
  25. #include <map>
  26. #include <algorithm>
  27. #include "framework/omg/omg_inner_types.h"
  28. #include "operator.h"
  29. #include "graph/operator_reg.h"
  30. #include "graph/operator_reg.h"
  31. #include "transfer_shape_according_to_format.h"
  32. #include "graph/utils/op_desc_utils.h"
  33. #include "graph/utils/tensor_utils.h"
  34. #include "graph/utils/node_utils.h"
  35. #include "graph/tensor.h"
  36. #include "graph/node.h"
  37. #include "graph/ge_tensor.h"
  38. #include "op_log.h"
  // Error-print helper: hands a printf-style format string and any extra arguments straight to
  // printf. Nothing is prepended or appended, so callers supply their own newline. When no extra
  // arguments are passed, the ## paste drops the dangling comma (GCC/Clang extension).
  #define LOG_ERROR(format, ...) printf(format, ##__VA_ARGS__)
  40. namespace ge {
  41. // enum type and string type mapping
  42. static const std::map<ge::DataType, std::string> DTYPE_STR_MAP{
  43. {ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"}, {ge::DT_INT16, "int16"},
  44. {ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"}, {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"},
  45. {ge::DT_UINT32, "uint32"}, {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}};
  // Named input counts, so shape functions can write INPUT_NUMk instead of a bare k.
  constexpr size_t INPUT_NUM0 = 0;
  constexpr size_t INPUT_NUM1 = 1;
  constexpr size_t INPUT_NUM2 = 2;
  constexpr size_t INPUT_NUM3 = 3;
  constexpr size_t INPUT_NUM4 = 4;
  constexpr size_t INPUT_NUM5 = 5;
  constexpr size_t INPUT_NUM6 = 6;
  constexpr size_t INPUT_NUM7 = 7;
  constexpr size_t INPUT_NUM8 = 8;
  constexpr size_t INPUT_NUM9 = 9;

  // Named shape ranks (number of dimensions).
  constexpr size_t DIM_SIZE0 = 0;
  constexpr size_t DIM_SIZE1 = 1;
  constexpr size_t DIM_SIZE2 = 2;
  constexpr size_t DIM_SIZE3 = 3;
  constexpr size_t DIM_SIZE4 = 4;
  constexpr size_t DIM_SIZE5 = 5;
  constexpr size_t DIM_SIZE6 = 6;
  constexpr size_t DIM_SIZE7 = 7;
  constexpr size_t DIM_SIZE8 = 8;

  // Named positions of a dimension inside a shape.
  constexpr size_t DIM_INDEX0 = 0;
  constexpr size_t DIM_INDEX1 = 1;
  constexpr size_t DIM_INDEX2 = 2;
  constexpr size_t DIM_INDEX3 = 3;
  constexpr size_t DIM_INDEX4 = 4;
  constexpr size_t DIM_INDEX5 = 5;
  constexpr size_t DIM_INDEX6 = 6;
  constexpr size_t DIM_INDEX7 = 7;
  constexpr size_t DIM_INDEX8 = 8;
  77. /*
  78. * get the datatype of input
  79. * param[in] dataType input datatype of enum value
  80. * param[in] supportList the support range of op
  81. * return true :get type success
  82. * false:get type failed
  83. */
  84. bool GetInputDataType(const ge::DataType& data_type, const std::vector<ge::DataType>& supportList);
  85. bool GetInputDataType(const ge::DataType& dataType, const std::vector<ge::DataType>& supportList, std::string& dType);
  86. /* infer shape of two input and on output with broadcast
  87. * param[in] op op desc supply by ge
  88. * param[in] inputName1 first input name
  89. * param[in] inputName2 second input name
  90. * param[in] outputName output name
  91. * return SUCCESS:infer success
  92. * FAILED:infer failed like unsupported broadcast input shape
  93. */
  94. bool CheckInputDataType(const Operator& op, const std::string& input_name,
  95. const std::vector<ge::DataType>& support_list);
  96. /*
  97. * check the datatype and shape of input
  98. * param[in] op the operator
  99. * param[in] inputTensorMap the map of input name and support datatype
  100. * param[in] paramType the mode of input param, tensor or scalar
  101. * return true
  102. * false
  103. */
  104. bool CheckInputDtypeAndShape(const Operator& op, const std::map<std::string, std::vector<DataType>>& inputTensorMap);
  105. /*
  106. * infer shape of two input and on output with broadcast
  107. * param[in] op op desc supply by ge
  108. * param[in] inputName1 first input name
  109. * param[in] inputName2 second input name
  110. * param[in] outputName output name
  111. * return SUCCESS:infer success
  112. * FAILED:infer failed like unsupported broadcast input shape
  113. */
  114. bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
  115. const string& output_name);
  116. /*
  117. * infer shape of two input and on output with broadcast
  118. * param[in] op op desc supply by ge
  119. * param[in] inputName1 first input name
  120. * param[in] inputName2 second input name
  121. * param[in] outputName output name
  122. * param[in] is_dynamic whether the shape of output is dynamic shape
  123. * return SUCCESS:infer success
  124. * FAILED:infer failed like unsupported broadcast input shape
  125. */
  126. bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
  127. const string& output_name, bool& is_dynamic);
  128. bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2,
  129. const string& output_name);
  130. bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name,
  131. const std::vector<ge::DataType>& supportList);
  132. bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2);
  133. bool CheckInputDtypeSame(const Operator& op, std::vector<std::string>& input_tensors);
  134. bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector<std::string>& input_names);
  135. bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value);
  136. bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value);
  137. bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value);
  138. bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector<int32_t>& attr_value);
  139. /**
  140. * Get int type const value from tensor data
  141. * @param [in] data const tensor data
  142. * @param [in] data_type DT_INT8, DT_INT16, DT_INT32, DT_INT64
  143. * @param [out] const_values const int values
  144. * @return true:success, false:failed.
  145. */
  146. bool GetConstIntData(const Tensor& data, DataType data_type, std::vector<int64_t>& const_values);
  147. bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype,
  148. std::vector<int64_t>& const_data);
  149. bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, const DataType& dtype,
  150. std::vector<int64_t>& const_data);
  151. bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data);
  152. bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
  153. const string& output_name);
  154. /*
  155. * Check input dtype and format is supported in supportList from inputNumBeg to inputNumEnd
  156. * param[in] op op desc supply by ge
  157. * param[in] inputNumBeg input index begin, [0, N]
  158. * param[in] inputNumEnd input index end need to be checked
  159. * param[in] supportList, support type of ge::DataType and ge::Format
  160. * return true: check pass
  161. * false: check failed
  162. */
  163. template <typename T>
  164. bool CheckSimilarInputDtypeAndFormat(const Operator& op, std::size_t inputNumBeg, std::size_t inputNumEnd,
  165. const std::vector<T>& supportList) {
  166. for (std::size_t i = inputNumBeg; i < inputNumEnd; i++) {
  167. if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
  168. ge::DataType inType = op.GetInputDesc(i).GetDataType();
  169. const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
  170. if (findDtype == supportList.end()) {
  171. return false;
  172. }
  173. } else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
  174. ge::Format inType = op.GetInputDesc(i).GetFormat();
  175. const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
  176. if (findDtype == supportList.end()) {
  177. return false;
  178. }
  179. }
  180. }
  181. return true;
  182. }
  183. /*
  184. * Check input dtype and format is supported in supportList from inputNumBeg to inputNumEnd
  185. * param[in] op op desc supply by ge
  186. * param[in] indexNeedCheck input index need to be checked
  187. * param[in] supportList, support type of ge::DataType and ge::Format
  188. * return true: check pass
  189. * false: check failed
  190. */
  191. template <typename T>
  192. bool CheckSimilarInputDtypeAndFormat(const Operator& op, const std::vector<std::size_t>& indexNeedCheck,
  193. const std::vector<T>& supportList) {
  194. for (auto i : indexNeedCheck) {
  195. if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
  196. ge::DataType inType = op.GetInputDesc(i).GetDataType();
  197. const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
  198. if (findDtype == supportList.end()) {
  199. return false;
  200. }
  201. } else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
  202. ge::Format inType = op.GetInputDesc(i).GetFormat();
  203. const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
  204. if (findDtype == supportList.end()) {
  205. return false;
  206. }
  207. }
  208. }
  209. return true;
  210. }
  211. /*
  212. * get const attr
  213. * param[in] op op desc supply by ge
  214. * param[in] attrName list need to be get
  215. * param[out] attr vector
  216. * return true: get success
  217. * false: get failed
  218. */
  219. template <typename T>
  220. bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList, std::vector<T>& attrVec) {
  221. T value;
  222. for (auto name : attrNameList) {
  223. if (op.GetAttr(name, value) != ge::GRAPH_SUCCESS) {
  224. return false;
  225. }
  226. attrVec.push_back(value);
  227. }
  228. return true;
  229. }
  230. /*
  231. * get const attr list
  232. * param[in] op op desc supply by ge
  233. * param[in] attrName list need to be get
  234. * param[out] attr vector
  235. * return true: get success
  236. * false: get failed
  237. */
  238. template <typename T>
  239. bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList,
  240. std::vector<std::vector<T>>& attrListVec) {
  241. for (auto name : attrNameList) {
  242. std::vector<T> valueList;
  243. if (op.GetAttr(name, valueList) != ge::GRAPH_SUCCESS) {
  244. return false;
  245. }
  246. attrListVec.push_back(valueList);
  247. }
  248. return true;
  249. }
  250. std::string to_string(const vector<int64_t>& shape);
  251. std::string to_string(const ge::Shape& shape);
  252. std::string to_string(const ge::GeShape& shape);
  253. std::string to_string(const vector<pair<int64_t, int64_t>>& ranges);
  254. class DynamicShapeInfer {
  255. public:
  256. std::map<std::string, Format> map_format;
  257. std::map<std::string, DataType> map_dtype;
  258. std::map<std::string, uint32_t> inputs;
  259. std::map<std::string, uint32_t> outputs;
  260. Operator& op;
  261. OpDescPtr& op_desc;
  262. DynamicShapeInfer(Operator& op_v, OpDescPtr& opDesc_v) : op(op_v), op_desc(opDesc_v) {
  263. }
  264. bool CatchFormatAndShape();
  265. bool UpdateFormatAndShape();
  266. ~DynamicShapeInfer() {
  267. UpdateFormatAndShape();
  268. }
  269. };
  /*
   * Declares a local `op_desc` holding OpDescUtils::GetOpDescFromOperator(op), where `op` must
   * be in scope at the expansion site. When depends_names is non-empty, the macro registers it
   * via op_desc->SetOpInferDepends. The argument is parenthesized so that any expression (for
   * example a conditional) expands correctly. A null OpDesc is skipped instead of dereferenced.
   */
  #define PREPARE_DYNAMIC_SHAPE(depends_names)                               \
    auto op_desc = OpDescUtils::GetOpDescFromOperator(op);                  \
    do {                                                                     \
      if (op_desc != nullptr && !(depends_names).empty()) {                  \
        op_desc->SetOpInferDepends((depends_names));                         \
      }                                                                      \
    } while (0)
  276. bool IsEmptyTensor(const std::vector<int64_t>& dims);
  277. bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
  278. bool IsUnknownRankShape(const std::vector<int64_t>& shape_vec);
  279. bool IsUnKnownShape(const std::vector<int64_t>& shape_vec);
  280. bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
  281. bool IsUnknownVec(std::vector<int64_t>& shape_vec);
  282. bool IsUnknown(const std::vector<int64_t>& shape_vec);
  283. void MakeUpShapeRange(const std::vector<int64_t>& shape, std::vector<std::pair<int64_t, int64_t>>& range);
  284. std::string DataTypeToStringDesc(const ge::DataType& dataType);
  285. bool OneInOneOutDynamicInfer(const Operator& op,
  286. const std::string& input_name,
  287. const std::vector<std::string>& output_name_list);
  288. bool TwoInOneOutDynamicInferNoBroadcast(Operator& op,
  289. const string& input1_name,
  290. const string& input2_name,
  291. const std::vector<string>& output_name_list);
  292. void FixShapeRangeWithDims(const std::vector<int64_t>& dims,
  293. std::vector<int64_t>& shape_1,
  294. std::vector<int64_t>& shape_2,
  295. std::vector<std::pair<int64_t, int64_t>>& range_1,
  296. std::vector<std::pair<int64_t, int64_t>>& range_2);
  297. bool SetScalarOutputDesc(const string& input,
  298. const string& output,
  299. OpDescPtr op_desc,
  300. GeShape& output_shape);
  301. namespace array_ops {
  302. bool CheckInt64MulOverflow(int64_t a, int64_t b);
  303. void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
  304. int64_t& range_max);
  305. void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
  306. std::vector<std::pair<int64_t, int64_t>>& y_range, GeShape& output_shape);
  307. }
  308. } // namespace ge
  309. #endif // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：