You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

util.cc 43 kB

modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: 
tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: 
tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: 
tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: 
tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: 
tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: 
tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file util.cpp
  18. * \brief
  19. */
  20. #include "util.h"
  21. #include <string>
  22. #include <vector>
  23. #include <map>
  24. #include <functional>
  25. #include <algorithm>
  26. #include "./error_util.h"
  27. #include "op_common_util.h"
  28. #include "graph/utils/type_utils.h"
  29. #include "axis_util.h"
  30. namespace ge {
  31. bool GetInputDataType(const ge::DataType& data_type, const std::vector<ge::DataType>& supportList) {
  32. std::vector<ge::DataType>::const_iterator supportIter = find(supportList.begin(), supportList.end(), data_type);
  33. if (supportIter == supportList.end()) {
  34. return false;
  35. }
  36. return true;
  37. }
  38. bool CheckInputDtypeAndShape(const Operator& op, const std::map<std::string, std::vector<DataType>>& inputTensorMap) {
  39. auto iter = inputTensorMap.begin();
  40. auto first_name = iter->first;
  41. auto first_shape_dims = op.GetInputDesc(iter->first).GetShape().GetDims();
  42. auto first_input_dtype = op.GetInputDesc(iter->first).GetDataType();
  43. for (; iter != inputTensorMap.end(); ++iter) {
  44. const TensorDesc input_desc = op.GetInputDesc(iter->first);
  45. // check input dtype
  46. auto input_type = input_desc.GetDataType();
  47. if (input_type != first_input_dtype) {
  48. OP_LOGE(op.GetName().c_str(), "the op type of param %s must equal with param %s", iter->first.c_str(),
  49. first_name.c_str());
  50. return false;
  51. }
  52. auto dims = input_desc.GetShape().GetDims();
  53. if (dims != first_shape_dims) {
  54. OP_LOGE(op.GetName().c_str(), "the op shape of param %s must equal with param %s", iter->first.c_str(),
  55. first_name.c_str());
  56. return false;
  57. }
  58. }
  59. return true;
  60. }
  61. bool CheckInputDataType(const Operator& op, const std::string& input_name,
  62. const std::vector<ge::DataType>& support_list) {
  63. bool valid = false;
  64. DataType input_type = op.GetInputDesc(input_name).GetDataType();
  65. do {
  66. const auto& found_list = find(support_list.begin(), support_list.end(), input_type);
  67. if (found_list == support_list.end()) {
  68. break;
  69. }
  70. const auto& found_map = DTYPE_STR_MAP.find(input_type);
  71. if (found_map == DTYPE_STR_MAP.end()) {
  72. break;
  73. }
  74. valid = true;
  75. } while (0);
  76. if (!valid) {
  77. OpsInputDtypeErrReport(op.GetName(), input_name, DebugString(support_list), ConcatString(input_type));
  78. OP_LOGE(op.GetName().c_str(), "The op do not support the dtype %s",
  79. ge::TypeUtils::DataTypeToSerialString(input_type).c_str());
  80. return false;
  81. }
  82. return true;
  83. }
/**
 * @brief Check that two named inputs of `op` carry the same data type.
 * @param op           operator to query
 * @param input_name1  first input name
 * @param input_name2  second input name
 * @return true when both dtypes match; false on mismatch or invalid OpDesc
 */
bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  // CHECK logs "invalid OpDesc." and returns false if the OpDesc or either
  // input desc is missing (the macro embeds the log call and the return).
  CHECK(op_desc == nullptr || op_desc->MutableInputDesc(input_name1) == nullptr ||
        op_desc->MutableInputDesc(input_name2) == nullptr,
        OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  DataType input_type_x1 = op_desc->MutableInputDesc(input_name1)->GetDataType();
  DataType input_type_x2 = op_desc->MutableInputDesc(input_name2)->GetDataType();
  if (input_type_x1 != input_type_x2) {
    // Mismatch is reported on both channels: structured error report + op log.
    OpsTwoInputDtypeErrReport(op.GetName(), input_name1, input_name2, ConcatString(input_type_x1),
                              ConcatString(input_type_x2));
    OP_LOGE(op.GetName().c_str(), "The %s op dtype is not same, type1:%s, type2:%s", op.GetName().c_str(),
            ge::TypeUtils::DataTypeToSerialString(input_type_x1).c_str(),
            ge::TypeUtils::DataTypeToSerialString(input_type_x2).c_str());
    return false;
  }
  return true;
}
  101. bool CheckInputDtypeSame(const Operator& op, std::vector<std::string>& input_tensors) {
  102. auto first_name = input_tensors.begin();
  103. auto first_input_dtype = op.GetInputDesc(*first_name).GetDataType();
  104. for (const string& input_name : input_tensors) {
  105. const TensorDesc input_desc = op.GetInputDesc(input_name);
  106. auto input_dtype = input_desc.GetDataType();
  107. if (input_dtype != first_input_dtype) {
  108. OP_LOGE(op.GetName().c_str(), "the op type of param %s must equal with param %s", input_name.c_str(),
  109. (*first_name).c_str());
  110. return false;
  111. }
  112. }
  113. return true;
  114. }
  115. bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector<std::string>& input_names) {
  116. auto first_input_name = input_names.begin();
  117. auto first_input_des = op.GetInputDesc(*first_input_name);
  118. auto input_name = first_input_name;
  119. for (++input_name; input_name != input_names.end(); ++input_name) {
  120. auto input_des = op.GetInputDesc(*first_input_name);
  121. if (input_des.GetDataType() != first_input_des.GetDataType() ||
  122. input_des.GetShape().GetDims() != first_input_des.GetShape().GetDims()) {
  123. OpsAttrValueErrReport(
  124. op.GetName(), ConcatString(input_name->c_str(), "'s dtype and shape"),
  125. ConcatString("same as", first_input_name->c_str(), "[", first_input_des.GetDataType(), "]", "[",
  126. DebugString(first_input_des.GetShape().GetDims()), "]"),
  127. ConcatString("[", input_des.GetDataType(), "]", "[", DebugString(input_des.GetShape().GetDims()), "]"));
  128. OP_LOGE(op.GetName().c_str(), "the dtype and shape of param %s must be same as param %s",
  129. first_input_name->c_str(), input_name->c_str());
  130. return false;
  131. }
  132. }
  133. return true;
  134. }
/**
 * @brief Infer the broadcast output shape/dtype for a two-input, one-output
 *        op, with dynamic-shape support.
 *
 * The output dtype is taken from input_name1. Shapes are broadcast with
 * NumPy-style trailing alignment: the shorter shape is left-padded with 1s,
 * then each dimension pair must be equal, or contain a 1, or contain -1
 * (unknown dim). -1/-1 rank means unknown rank.
 *
 * @param op            operator (mutated: output desc shape/dtype are set)
 * @param input_name1   first input; also the dtype source
 * @param input_name2   second input
 * @param output_name   output whose desc is written
 * @param is_dynamic    out-param: true iff the inferred shape contains -1
 *                      (unknown-rank output reports is_dynamic = false)
 * @return false on invalid descs, broadcast-rule violation, or range
 *         inference failure; true otherwise
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
                                           const string& output_name, bool& is_dynamic) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  // CHECK logs and returns false when the OpDesc or any named desc is missing.
  CHECK(op_desc == nullptr || op_desc->MutableOutputDesc(output_name) == nullptr ||
        op_desc->MutableInputDesc(input_name1) == nullptr || op_desc->MutableInputDesc(input_name2) == nullptr,
        OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  DataType input_dtype = op_desc->MutableInputDesc(input_name1)->GetDataType();
  // output Desc: dtype is fixed up-front, shape is set on each return path
  GeTensorDescPtr tensordesc_output = op_desc->MutableOutputDesc(output_name);
  tensordesc_output->SetDataType(input_dtype);
  ge::GeShape shapeX = op_desc->MutableInputDesc(input_name1)->GetShape();
  ge::GeShape shapeY = op_desc->MutableInputDesc(input_name2)->GetShape();
  OP_LOGI(op.GetName().c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(),
          input_name2.c_str(), to_string(shapeY).c_str());
  std::vector<int64_t> dimsX = shapeX.GetDims();
  std::vector<int64_t> dimsY = shapeY.GetDims();
  // swap based on shape size so dimsX is always the longer (or equal) shape
  if (dimsX.size() < dimsY.size()) {
    std::vector<int64_t> dimsTmp = dimsX;
    dimsX = dimsY;
    dimsY = dimsTmp;
  }
  std::vector<int64_t> dimVec;
  // unknown rank: either input of unknown rank forces an unknown-rank output
  if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) {
    tensordesc_output->SetShape(ge::GeShape(UNKNOWN_RANK));
    OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%d.", to_string(ge::Shape(UNKNOWN_RANK)).c_str(),
            input_dtype);
    // NOTE: unknown rank is reported as NOT dynamic here; range inference is skipped.
    is_dynamic = false;
    return true;
  }
  // pad 1 for small shape (left-pad dimsY up to dimsX's rank)
  if (dimsX.size() != dimsY.size()) {
    int dec = dimsX.size() - dimsY.size();
    for (int i = 0; i < dec; i++) {
      dimsY.insert(dimsY.begin(), (int64_t)1);
    }
  }
  // when not dynamic case, do infer shape only: per-dim max, with any 0 dim
  // propagating as 0 (empty-tensor semantics)
  if (!IsUnknown(dimsY) && !IsUnknown(dimsX)) {
    for (size_t i = 0; i < dimsX.size(); i++) {
      int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
      dims = (dimsY[i] == 0 || dimsX[i] == 0) ? 0 : dims;
      dimVec.push_back(dims);
    }
    tensordesc_output->SetShape(ge::GeShape(dimVec));
    is_dynamic = false;
    return true;
  }
  // dynamic case: at least one dim is -1 (unknown)
  for (size_t i = 0; i < dimsX.size(); i++) {
    // broadcast rule violation: dims differ and neither is 1 nor unknown
    CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1),
          OpsInputShapeBroadcastErrReport(op.GetName(), input_name1, input_name2, ConcatString(dimsX[i]),
                                          ConcatString(dimsY[i]));
          OP_LOGE(op.GetName().c_str(), "The %s's dimensions does not match the broadcast rule(%lu %lu).",
                  op.GetName().c_str(), dimsX[i], dimsY[i]),
          return false);
    if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
      // X unknown, Y known: Y>1 fixes the dim to Y; Y==1 leaves it unknown (-1);
      // Y==0 forces an empty dim
      if (dimsY[i] > 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
      } else if (dimsY[i] == 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
        dimVec[i] = -1;
      } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
        dimVec.push_back(0);
      }
    } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
      // mirror case: Y unknown, X known
      if (dimsX[i] > 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
      } else if (dimsX[i] == 0) {
        dimVec.push_back(0);
      } else if (dimsX[i] == 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
        dimVec[i] = -1;
      }
    } else {
      if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
        // both unknown: result unknown
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
        dimVec[i] = -1;
      } else {
        // both known: 0 propagates, otherwise per-dim max
        if (dimsY[i] == 0 || dimsX[i] == 0) {
          dimVec.push_back(0);
        } else {
          int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
          dimVec.push_back(dims);
        }
      }
    }
  }
  ge::GeShape outputShape = ge::GeShape(dimVec);
  tensordesc_output->SetShape(outputShape);
  OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(),
          ge::TypeUtils::DataTypeToSerialString(input_dtype).c_str());
  is_dynamic = IsUnknown(dimVec);
  if (is_dynamic) {
    // a -1 in the output requires shape-range propagation as well
    if (!InferShapeRangeTwoInOneOutBroadcase(op, input_name1, input_name2, output_name)) {
      return false;
    }
  }
  return true;
}
/**
 * @brief Infer the broadcast output shape/dtype for a two-input, one-output
 *        op (overload without dynamic-shape range propagation).
 *
 * Same broadcast rules as the five-argument overload: shorter shape is
 * left-padded with 1s, each dimension pair must be equal, contain a 1, or
 * contain -1 (unknown); unknown rank in either input yields an unknown-rank
 * output. The output dtype is taken from input_name1.
 *
 * @param op           operator (mutated: output desc shape/dtype are set)
 * @param input_name1  first input; also the dtype source
 * @param input_name2  second input
 * @param output_name  output whose desc is written
 * @return false on invalid descs or broadcast-rule violation; true otherwise
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
                                           const string& output_name) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  // CHECK logs and returns false when the OpDesc or any named desc is missing.
  CHECK(op_desc == nullptr || op_desc->MutableInputDesc(input_name1) == nullptr ||
        op_desc->MutableOutputDesc(output_name) == nullptr || op_desc->MutableInputDesc(input_name2) == nullptr,
        OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  DataType input_dtype = op_desc->MutableInputDesc(input_name1)->GetDataType();
  GeTensorDescPtr tensordesc_output = op_desc->MutableOutputDesc(output_name);
  ge::GeShape shapeX = op_desc->MutableInputDesc(input_name1)->GetShape();
  ge::GeShape shapeY = op_desc->MutableInputDesc(input_name2)->GetShape();
  OP_LOGI(op.GetName().c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(),
          input_name2.c_str(), to_string(shapeY).c_str());
  std::vector<int64_t> dimsX = shapeX.GetDims();
  std::vector<int64_t> dimsY = shapeY.GetDims();
  // swap based on shape size so dimsX is always the longer (or equal) shape
  if (dimsX.size() < dimsY.size()) {
    std::vector<int64_t> dimsTmp = dimsX;
    dimsX = dimsY;
    dimsY = dimsTmp;
  }
  std::vector<int64_t> dimVec;
  // unknown rank: either input of unknown rank forces an unknown-rank output
  if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) {
    tensordesc_output->SetShape(ge::GeShape(UNKNOWN_RANK));
    tensordesc_output->SetDataType(input_dtype);
    OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%d.", to_string(ge::Shape(UNKNOWN_RANK)).c_str(),
            input_dtype);
    return true;
  }
  // pad 1 for small shape (left-pad dimsY up to dimsX's rank)
  if (dimsX.size() != dimsY.size()) {
    int dec = dimsX.size() - dimsY.size();
    for (int i = 0; i < dec; i++) {
      dimsY.insert(dimsY.begin(), (int64_t)1);
    }
  }
  for (size_t i = 0; i < dimsX.size(); i++) {
    // broadcast rule violation: dims differ and neither is 1 nor unknown
    CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1),
          OpsInputShapeBroadcastErrReport(op.GetName(), input_name1, input_name2, ConcatString(dimsX[i]),
                                          ConcatString(dimsY[i]));
          OP_LOGE(op.GetName().c_str(), "The %s's dimensions does not match the broadcast rule(%lu %lu).",
                  op.GetName().c_str(), dimsX[i], dimsY[i]),
          return false);
    if ((dimsX[i] == -1) && (dimsY[i] != -1)) {
      // X unknown, Y known: Y>1 fixes the dim; Y==1 keeps it unknown; 0 propagates
      if (dimsY[i] > 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
      } else if (dimsY[i] == 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
        dimVec[i] = -1;
      } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) {
        dimVec.push_back(0);
      }
    } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) {
      // mirror case: Y unknown, X known
      if (dimsX[i] > 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
      } else if (dimsX[i] == 0) {
        dimVec.push_back(0);
      } else if (dimsX[i] == 1) {
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
        dimVec[i] = -1;
      }
    } else {
      if ((dimsX[i] == -1) && (dimsY[i] == -1)) {
        // both unknown: result unknown
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
        dimVec[i] = -1;
      } else {
        // both known: 0 propagates, otherwise per-dim max
        if (dimsY[i] == 0 || dimsX[i] == 0) {
          dimVec.push_back(0);
        } else {
          int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
          dimVec.push_back(dims);
        }
      }
    }
  }
  ge::GeShape outputShape = ge::GeShape(dimVec);
  tensordesc_output->SetShape(outputShape);
  tensordesc_output->SetDataType(input_dtype);
  OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(),
          ge::TypeUtils::DataTypeToSerialString(input_dtype).c_str());
  return true;
}
// Infers the output shape range for a two-input/one-output broadcast op.
// ("Broadcase" typo is part of the public name and must stay for callers.)
// Reads both inputs' shapes and ranges, aligns them to the same rank by
// left-padding the shorter range with (1, 1), then merges them element-wise
// into the range of `output_name`. Returns false only when the OpDesc or one
// of the named tensor descs is missing.
bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2,
                                         const string& output_name) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  CHECK(op_desc == nullptr || op_desc->MutableInputDesc(input_name1) == nullptr ||
        op_desc->MutableOutputDesc(output_name) == nullptr || op_desc->MutableInputDesc(input_name2) == nullptr,
        OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  ge::GeShape shape_x = op_desc->MutableInputDesc(input_name1)->GetShape();
  ge::GeShape shape_y = op_desc->MutableInputDesc(input_name2)->GetShape();
  std::vector<int64_t> dims_x = shape_x.GetDims();
  std::vector<int64_t> dims_y = shape_y.GetDims();
  std::vector<std::pair<int64_t, int64_t>> shape_range_x;
  op_desc->MutableInputDesc(input_name1)->GetShapeRange(shape_range_x);
  std::vector<std::pair<int64_t, int64_t>> shape_range_y;
  op_desc->MutableInputDesc(input_name2)->GetShapeRange(shape_range_y);
  // Fill trivial (d, d) ranges for static dims so both range vectors are usable.
  MakeUpShapeRange(dims_x, shape_range_x);
  MakeUpShapeRange(dims_y, shape_range_y);
  ge::GeShape shape_out = op_desc->MutableOutputDesc(output_name)->GetShape();
  std::vector<int64_t> dims_out = shape_out.GetDims();
  size_t size_shape_out = dims_out.size();
  std::vector<std::pair<int64_t, int64_t>> out_range;
  // Unknown-rank output ({-2}): set an empty range below and return.
  if (!IsUnknownRankShape(dims_out)) {
    // shape switch by shape dim size: make x the longer (or equal-rank) operand
    if (dims_x.size() < dims_y.size()) {
      std::vector<int64_t> dims_tmp = dims_x;
      dims_x = dims_y;
      dims_y = dims_tmp;
      std::vector<std::pair<int64_t, int64_t>> range_temp = shape_range_x;
      shape_range_x = shape_range_y;
      shape_range_y = range_temp;
    }
    // Left-pad the shorter range with (1, 1) entries so both align with dims_x.
    while (dims_x.size() > shape_range_y.size()) {
      shape_range_y.insert(shape_range_y.begin(), std::pair<int64_t, int64_t>(1, 1));
    }
    for (size_t i = 0; i < size_shape_out; i++) {
      if (dims_out[i] != -1) {
        // Static output dim: range is exactly that value.
        out_range.push_back(std::pair<int64_t, int64_t>(dims_out[i], dims_out[i]));
        continue;
      }
      if (i < shape_range_x.size() && i < shape_range_y.size()) {
        if (shape_range_x[i].second == -1 && shape_range_y[i].second == 1) {
          // One side unbounded, the other fixed at 1: result unbounded.
          out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
        } else if (shape_range_x[i].second == 1 && shape_range_y[i].second == -1) {
          out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
        } else if (shape_range_x[i].first == 1 || shape_range_y[i].first == 1) {
          // one shape size maybe 1, so will support boardcast
          // first_range == max first
          int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first);
          int64_t second_range = shape_range_x[i].first == 1 ? shape_range_y[i].second : shape_range_x[i].second;
          if (shape_range_x[i].first == 1 && shape_range_y[i].first == 1) {
            // Both sides may broadcast: take the larger upper bound (-1 wins).
            second_range = std::max(shape_range_x[i].second, shape_range_y[i].second);
            second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1) ? -1 : second_range;
          }
          out_range.push_back(std::pair<int64_t, int64_t>(first_range, second_range));
        } else {
          // no 1 in range.first, mean no boardcast for range
          // get intersect range
          int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first);
          int64_t second_range = std::min(shape_range_x[i].second, shape_range_y[i].second);
          second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1)
                             ? std::max(shape_range_x[i].second, shape_range_y[i].second)
                             : second_range;
          out_range.push_back(std::pair<int64_t, int64_t>(first_range, second_range));
        }
      }
    }
  }
  GeTensorDescPtr tensor_out = op_desc->MutableOutputDesc(output_name);
  tensor_out->SetShapeRange(out_range);
  return true;
}
  398. bool GetInputDataType(const ge::DataType& dataType, const std::vector<ge::DataType>& supportList, std::string& dType) {
  399. std::vector<ge::DataType>::const_iterator supportIter = find(supportList.begin(), supportList.end(), dataType);
  400. if (supportIter == supportList.end()) {
  401. return false;
  402. }
  403. std::map<ge::DataType, std::string>::const_iterator totalIter = DTYPE_STR_MAP.find(dataType);
  404. if (totalIter == DTYPE_STR_MAP.end()) {
  405. return false;
  406. }
  407. dType = totalIter->second;
  408. return true;
  409. }
  410. bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name,
  411. const std::vector<ge::DataType>& supportList) {
  412. DataType input_type = op.GetInputDesc(input_name).GetDataType();
  413. if (false == GetInputDataType(input_type, supportList, *data_type)) {
  414. LOG_ERROR("[ERROR]op [%s] [%s] do not supported dtype [%s]!\n", op.GetName().c_str(), input_name.c_str(),
  415. data_type->c_str());
  416. return false;
  417. }
  418. return true;
  419. }
  420. bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value) {
  421. if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) {
  422. LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str());
  423. return false;
  424. }
  425. return true;
  426. }
  427. bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value) {
  428. if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) {
  429. LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str());
  430. return false;
  431. }
  432. return true;
  433. }
  434. bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value) {
  435. if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) {
  436. LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str());
  437. return false;
  438. }
  439. return true;
  440. }
  441. bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector<int32_t>& attr_value) {
  442. if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) {
  443. LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str());
  444. return false;
  445. }
  446. return true;
  447. }
  448. template <typename T>
  449. static std::vector<int64_t> GetConstIntData(const uint8_t* const_data, size_t data_size) {
  450. size_t size = data_size / sizeof(T);
  451. std::vector<int64_t> result(size);
  452. T* data = (T*)const_data;
  453. for (size_t i = 0; i < size; i++) {
  454. result[i] = *(data + i);
  455. }
  456. return result;
  457. }
  458. bool GetConstIntData(const Tensor& data, DataType data_type, std::vector<int64_t>& const_values) {
  459. using namespace std::placeholders;
  460. const std::map<DataType, std::function<std::vector<int64_t>(const uint8_t*, size_t)>> type_call_map = {
  461. {DT_INT8, std::bind(GetConstIntData<int8_t>, _1, _2)},
  462. {DT_INT16, std::bind(GetConstIntData<int16_t>, _1, _2)},
  463. {DT_INT32, std::bind(GetConstIntData<int32_t>, _1, _2)},
  464. {DT_INT64, std::bind(GetConstIntData<int64_t>, _1, _2)},
  465. };
  466. auto found = type_call_map.find(data_type);
  467. if (found == type_call_map.end()) {
  468. USER_GE_LOGE("[ERROR]GetConstIntData is not support data_type[%s]!",
  469. ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
  470. return false;
  471. }
  472. const_values = found->second(data.GetData(), data.GetSize());
  473. return true;
  474. }
  475. bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype,
  476. std::vector<int64_t>& const_data) {
  477. size_t size = 0;
  478. CHECK(dtype != ge::DT_INT32 && dtype != ge::DT_INT64,
  479. OP_LOGE(op.GetName().c_str(), "not support this type"), return false);
  480. if (dtype == ge::DT_INT32) {
  481. int32_t* const_data_ptr = (int32_t*)const_tensor.GetData();
  482. size = const_tensor.GetSize() / sizeof(int32_t);
  483. for (size_t i = 0; i < size; ++i) {
  484. const_data.push_back((int32_t)((*(const_data_ptr + i))));
  485. OP_LOGD(op.GetName().c_str(), "const data int32 fusion pass ====== %d", (int32_t)(*(const_data_ptr + i)));
  486. }
  487. } else if (dtype == ge::DT_INT64) {
  488. int64_t* const_data_ptr = (int64_t*)const_tensor.GetData();
  489. size = const_tensor.GetSize() / sizeof(int64_t);
  490. for (size_t i = 0; i < size; ++i) {
  491. const_data.push_back(((int64_t)(*(const_data_ptr + i))));
  492. OP_LOGD(op.GetName().c_str(), "const data int64 fusion pass ====== %d", (int64_t)(*(const_data_ptr + i)));
  493. }
  494. }
  495. return true;
  496. }
  497. bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor,
  498. const DataType& dtype, std::vector<int64_t>& const_data) {
  499. size_t size = const_tensor->GetData().GetSize();
  500. void* data_ptr = (void*)const_tensor->GetData().GetData();
  501. CHECK(data_ptr == nullptr, OP_LOGE(op.GetName().c_str(), "data is null."), return false);
  502. CHECK(dtype != ge::DT_INT32 && dtype != ge::DT_INT64,
  503. OP_LOGE(op.GetName().c_str(), "const not support this type"), return false);
  504. if (dtype == ge::DT_INT32){
  505. int32_t* const_data_ptr = reinterpret_cast<int32_t*>(data_ptr);
  506. size = size / sizeof(int32_t);
  507. for (size_t i=0; i < size; i++) {
  508. const_data.push_back((int64_t)((int32_t) ((*(const_data_ptr + i)))));
  509. }
  510. } else if (dtype == ge::DT_INT64) {
  511. int64_t* const_data_ptr = reinterpret_cast<int64_t*>(data_ptr);
  512. size = size / sizeof(int64_t);
  513. for (size_t i=0; i < size; i++) {
  514. const_data.push_back((int64_t)((int64_t) ((*(const_data_ptr + i)))));
  515. }
  516. }
  517. return true;
  518. }
  519. bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data) {
  520. if (dtype == ge::DT_INT32) {
  521. int32_t* const_data_ptr = (int32_t*)const_tensor.GetData();
  522. const_data = (int32_t)(*const_data_ptr);
  523. } else if (dtype == ge::DT_INT64) {
  524. int64_t* const_data_ptr = (int64_t*)const_tensor.GetData();
  525. const_data = (int64_t)(*const_data_ptr);
  526. } else {
  527. OP_LOGE(op.GetName().c_str(), "not support this type");
  528. return false;
  529. }
  530. return true;
  531. }
// Renders a dims vector (e.g. {2, 3}) as a readable string; thin wrapper
// over ops::to_string.
string to_string(const vector<int64_t>& shape) {
  return ops::to_string(shape);
}
// Overload for ge::Shape: stringifies its dims vector.
std::string to_string(const ge::Shape& shape) {
  return to_string(shape.GetDims());
}
// Overload for ge::GeShape: stringifies its dims vector.
std::string to_string(const ge::GeShape& shape) {
  return to_string(shape.GetDims());
}
// Renders a shape-range list (min/max pairs) as a readable string; thin
// wrapper over ops::to_string.
std::string to_string(const vector<pair<int64_t, int64_t>>& ranges) {
  return ops::to_string(ranges);
}
  544. bool DynamicShapeInfer::CatchFormatAndShape() {
  545. inputs = op_desc->GetAllInputName();
  546. outputs = op_desc->GetAllOutputName();
  547. GeTensorDescPtr tensor_desc_input, tensor_desc_output;
  548. // get and save current input shape&format, and assign origin ones to them
  549. std::string input_name;
  550. for (map<std::string, uint32_t>::iterator it = inputs.begin(); it != inputs.end(); it++) {
  551. input_name = it->first;
  552. tensor_desc_input = op_desc->MutableInputDesc(input_name);
  553. if (tensor_desc_input == nullptr) {
  554. continue;
  555. }
  556. Format curr_format = tensor_desc_input->GetFormat();
  557. map_format.insert(std::pair<std::string, Format>(input_name, curr_format));
  558. map_dtype.insert(std::pair<std::string, DataType>(input_name, tensor_desc_input->GetDataType()));
  559. if (tensor_desc_input->GetOriginFormat() == curr_format) {
  560. continue;
  561. }
  562. tensor_desc_input->SetFormat(tensor_desc_input->GetOriginFormat());
  563. tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape());
  564. }
  565. // get and save current output shape&format, and assign origin ones to them
  566. std::string output_name;
  567. for (map<std::string, uint32_t>::iterator it = outputs.begin(); it != outputs.end(); it++) {
  568. output_name = it->first;
  569. tensor_desc_output = op_desc->MutableOutputDesc(output_name);
  570. if (tensor_desc_output == nullptr) {
  571. continue;
  572. }
  573. Format curr_format = tensor_desc_output->GetFormat();
  574. map_format.insert(std::pair<std::string, Format>(output_name, curr_format));
  575. map_dtype.insert(std::pair<std::string, DataType>(output_name, tensor_desc_output->GetDataType()));
  576. if (tensor_desc_output->GetOriginFormat() == curr_format) {
  577. continue;
  578. }
  579. tensor_desc_output->SetFormat(tensor_desc_output->GetOriginFormat());
  580. }
  581. return true;
  582. }
  583. bool DynamicShapeInfer::UpdateFormatAndShape() {
  584. const int64_t opImplType = EN_IMPL_CUSTOM_TBE;
  585. GeTensorDescPtr tensor_desc_input, tensor_desc_output;
  586. // assign output's after infershape to origin shape
  587. for (map<std::string, uint32_t>::iterator it = outputs.begin(); it != outputs.end(); it++) {
  588. tensor_desc_output = op_desc->MutableOutputDesc(it->first);
  589. if (tensor_desc_output == nullptr) {
  590. continue;
  591. }
  592. tensor_desc_output->SetOriginShape(tensor_desc_output->GetShape());
  593. }
  594. // transfer input's origin shape to current shape
  595. Format ori_input_format, cur_input_format;
  596. GeShape ori_infer_shape, current_shape;
  597. std::string input_name;
  598. for (map<std::string, uint32_t>::iterator it = inputs.begin(); it != inputs.end(); it++) {
  599. input_name = it->first;
  600. tensor_desc_input = op_desc->MutableInputDesc(input_name);
  601. if (tensor_desc_input == nullptr) {
  602. continue;
  603. }
  604. ori_input_format = tensor_desc_input->GetFormat();
  605. ori_infer_shape = tensor_desc_input->GetShape();
  606. cur_input_format = map_format[input_name];
  607. // print some info
  608. OP_LOGI(op.GetName().c_str(), "origin input shape %s is %s", input_name.c_str(),
  609. to_string(ori_infer_shape).c_str());
  610. ShapeAndFormat shapeAndFormatInfoInput = {ori_infer_shape, current_shape, ori_input_format,
  611. cur_input_format, map_dtype[input_name], opImplType};
  612. if (ori_input_format == cur_input_format) {
  613. // no need to transfer shape
  614. continue;
  615. } else {
  616. ShapeTransferAccordingToFormat* global_object = new ShapeTransferAccordingToFormat();
  617. CHECK(global_object == nullptr, OP_LOGE(op.GetName().c_str(), "new ShapeTransferAccordingToFormat failed."),
  618. return false);
  619. global_object->GetShapeAccordingToFormat(shapeAndFormatInfoInput);
  620. // print some info
  621. OP_LOGI(op.GetName().c_str(), "current input shape %s is %s", input_name.c_str(),
  622. to_string(current_shape).c_str());
  623. tensor_desc_input->SetFormat(cur_input_format);
  624. tensor_desc_input->SetShape(current_shape);
  625. delete global_object;
  626. }
  627. }
  628. // transfer output's origin shape to current shape
  629. Format ori_output_format, cur_output_format;
  630. GeShape ori_infer_out_shape, current_out_shape;
  631. std::string output_name;
  632. for (map<std::string, uint32_t>::iterator it = outputs.begin(); it != outputs.end(); it++) {
  633. output_name = it->first;
  634. tensor_desc_output = op_desc->MutableOutputDesc(output_name);
  635. if (tensor_desc_output == nullptr) {
  636. continue;
  637. }
  638. ori_output_format = tensor_desc_output->GetFormat();
  639. ori_infer_out_shape = tensor_desc_output->GetShape();
  640. cur_output_format = map_format[output_name];
  641. // print some info
  642. OP_LOGI(op.GetName().c_str(), "origin output shape %s is %s", output_name.c_str(),
  643. to_string(ori_infer_out_shape).c_str());
  644. ShapeAndFormat shapeAndFormatInfoOutput = {ori_infer_out_shape, current_out_shape, ori_output_format,
  645. cur_output_format, map_dtype[output_name], opImplType};
  646. if (ori_output_format == cur_output_format) {
  647. // no need to transfer shape
  648. continue;
  649. } else {
  650. ShapeTransferAccordingToFormat* global_object = new ShapeTransferAccordingToFormat();
  651. CHECK(global_object == nullptr, OP_LOGE(op.GetName().c_str(), "new ShapeTransferAccordingToFormat failed."),
  652. return false);
  653. global_object->GetShapeAccordingToFormat(shapeAndFormatInfoOutput);
  654. // print some info
  655. OP_LOGI(op.GetName().c_str(), "current output shape %s is %s", output_name.c_str(),
  656. to_string(current_out_shape).c_str());
  657. tensor_desc_output->SetFormat(cur_output_format);
  658. tensor_desc_output->SetShape(current_out_shape);
  659. delete global_object;
  660. }
  661. }
  662. return true;
  663. }
  664. bool IsEmptyTensor(const std::vector<int64_t>& dims) {
  665. if (dims.size() == 1 && dims[0] == 0) {
  666. return true;
  667. } else {
  668. return false;
  669. }
  670. }
  671. bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types) {
  672. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  673. CHECK(op_desc == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  674. GeTensorDescPtr tensor_desc;
  675. if (types == "input") {
  676. tensor_desc = op_desc->MutableInputDesc(tensor_name);
  677. } else if (types == "output") {
  678. tensor_desc = op_desc->MutableOutputDesc(tensor_name);
  679. } else {
  680. OP_LOGE(op.GetName().c_str(), "invalid params of types to judge.");
  681. return false;
  682. }
  683. std::vector<int64_t> shape_vec = tensor_desc->GetShape().GetDims();
  684. if (shape_vec.size() == 1 && shape_vec[0] == -2) {
  685. return true;
  686. }
  687. return false;
  688. }
  689. bool IsUnknownRankShape(const std::vector<int64_t>& shape_vec) {
  690. if (shape_vec.size() == 1 && shape_vec[0] == -2) {
  691. return true;
  692. }
  693. return false;
  694. }
  695. bool IsUnKnownShape(const std::vector<int64_t>& shape_vec) {
  696. auto found = find(shape_vec.begin(), shape_vec.end(), -1);
  697. return found != shape_vec.end();
  698. }
  699. bool IsUnknown(const std::vector<int64_t>& shape_vec) {
  700. return (IsUnKnownShape(shape_vec) || IsUnknownRankShape(shape_vec));
  701. }
  702. bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types) {
  703. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  704. CHECK(op_desc == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  705. GeTensorDescPtr tensor_desc;
  706. if (types == "input") {
  707. tensor_desc = op_desc->MutableInputDesc(tensor_name);
  708. } else if (types == "output") {
  709. tensor_desc = op_desc->MutableOutputDesc(tensor_name);
  710. } else {
  711. OP_LOGE(op.GetName().c_str(), "invalid params of types to judge.");
  712. return false;
  713. }
  714. std::vector<int64_t> shape_vec = tensor_desc->GetShape().GetDims();
  715. std::vector<int64_t>::iterator it_shape;
  716. it_shape = find(shape_vec.begin(), shape_vec.end(), -1);
  717. if (it_shape == shape_vec.end()) {
  718. return false;
  719. } else {
  720. return true;
  721. }
  722. }
  723. bool IsUnknownVec(std::vector<int64_t>& shape_vec) {
  724. std::vector<int64_t>::iterator it_shape;
  725. it_shape = find(shape_vec.begin(), shape_vec.end(), -1);
  726. if (it_shape == shape_vec.end()) {
  727. return false;
  728. } else {
  729. return true;
  730. }
  731. }
  732. void MakeUpShapeRange(const std::vector<int64_t>& shape, std::vector<std::pair<int64_t, int64_t>>& range) {
  733. if (IsUnknownRankShape(shape)) {
  734. return;
  735. }
  736. if (range.empty()) {
  737. for (size_t i = 0; i < shape.size(); i++) {
  738. if (shape[i] == -1) {
  739. range.push_back(std::pair<int64_t, int64_t>(1, -1));
  740. } else {
  741. range.push_back(std::pair<int64_t, int64_t>(shape[i], shape[i]));
  742. }
  743. }
  744. }
  745. }
  746. std::string DataTypeToStringDesc(const ge::DataType& dataType) {
  747. std::map<ge::DataType, std::string>::const_iterator totalIter = DTYPE_STR_MAP.find(dataType);
  748. if (totalIter == DTYPE_STR_MAP.end()) {
  749. return "UNDEFINED";
  750. }
  751. return totalIter->second;
  752. }
  753. bool OneInOneOutDynamicInfer(const Operator& op,
  754. const std::string& input_name,
  755. const std::vector<std::string>& output_name_list) {
  756. // get input desc
  757. auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  758. CHECK(op_info == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false);
  759. auto input_desc = op_info->MutableInputDesc(input_name);
  760. vector<int64_t> input_shape = input_desc->MutableShape().GetDims();
  761. DataType input_dtype = input_desc->GetDataType();
  762. if (IsUnknown(input_shape)) {
  763. std::vector<std::pair<int64_t, int64_t>> input_range;
  764. input_desc->GetShapeRange(input_range);
  765. MakeUpShapeRange(input_shape, input_range);
  766. auto output_desc = op_info->MutableOutputDesc(0);
  767. for (const string& output_name : output_name_list) {
  768. output_desc = op_info->MutableOutputDesc(output_name);
  769. output_desc->SetShape(GeShape(input_shape));
  770. output_desc->SetOriginShape(GeShape(input_shape));
  771. output_desc->SetShapeRange(input_range);
  772. output_desc->SetDataType(input_dtype);
  773. }
  774. } else {
  775. auto output_desc = op_info->MutableOutputDesc(0);
  776. for (const string& output_name : output_name_list) {
  777. output_desc = op_info->MutableOutputDesc(output_name);
  778. output_desc->SetShape(GeShape(input_shape));
  779. output_desc->SetDataType(input_dtype);
  780. }
  781. }
  782. return true;
  783. }
// Reconciles two shapes (and their ranges) that must agree dimension-wise.
// `dims` lists the dimension indices to fix up; when empty, every dimension
// is processed. For each selected dim:
//   - if either side is static, the value is copied to the other side and
//     both ranges collapse to exactly that value;
//   - if both are -1 (dynamic), the ranges are merged: min = max of the two
//     mins; max = min of the two maxes, with -1 treated as unbounded.
// If one shape is UNKNOWN_RANK, it (and its range) is replaced by the other
// side's. Mismatched rank or range size leaves everything untouched.
// NOTE(review): entries of `dims` are used as indices without bounds checks —
// assumes callers pass valid dimension indices; confirm at call sites.
void FixShapeRangeWithDims(const std::vector<int64_t>& dims,
                           std::vector<int64_t>& shape_1,
                           std::vector<int64_t>& shape_2,
                           std::vector<std::pair<int64_t, int64_t>>& range_1,
                           std::vector<std::pair<int64_t, int64_t>>& range_2) {
  // Ensure both ranges are populated before merging.
  MakeUpShapeRange(shape_1, range_1);
  MakeUpShapeRange(shape_2, range_2);
  bool is_all_fix = dims.empty();
  if (shape_1 == UNKNOWN_RANK && shape_2 == UNKNOWN_RANK) {
    return;
  }
  if (shape_1 == UNKNOWN_RANK) {
    shape_1 = shape_2;
    range_1 = range_2;
    return;
  }
  if (shape_2 == UNKNOWN_RANK) {
    shape_2 = shape_1;
    range_2 = range_1;
    return;
  }
  if ((shape_1.size() != shape_2.size()) || (range_1.size() != range_2.size())) {
    // Ranks disagree without an unknown-rank marker: nothing safe to do.
    return;
  }
  auto loop_size = is_all_fix ? shape_1.size() : dims.size();
  for (size_t i = 0; i < loop_size; i ++) {
    // When dims is empty, i itself is the dim index; otherwise dims[i] is.
    auto dim_num = is_all_fix ? i : dims[i];
    if (shape_1[dim_num] != -1) {
      // Side 1 is static: propagate the exact value everywhere.
      shape_2[dim_num] = shape_1[dim_num];
      range_1[dim_num] = std::pair<int64_t, int64_t>(shape_1[dim_num], shape_1[dim_num]);
      range_2[dim_num] = std::pair<int64_t, int64_t>(shape_1[dim_num], shape_1[dim_num]);
      continue;
    }
    if (shape_2[dim_num] != -1) {
      // Side 2 is static: propagate the exact value everywhere.
      shape_1[dim_num] = shape_2[dim_num];
      range_1[dim_num] = std::pair<int64_t, int64_t>(shape_2[dim_num], shape_2[dim_num]);
      range_2[dim_num] = std::pair<int64_t, int64_t>(shape_2[dim_num], shape_2[dim_num]);
      continue;
    }
    // both the dim in shape1 and shape2 are -1
    auto range_1_min = range_1[dim_num].first;
    auto range_2_min = range_2[dim_num].first;
    auto range_1_max = range_1[dim_num].second;
    auto range_2_max = range_2[dim_num].second;
    // Merged lower bound: the larger of the two minimums.
    auto range_fisrt = range_1_min > range_2_min ? range_1_min : range_2_min;
    // Merged upper bound: the smaller max, unless it is -1 (unbounded),
    // in which case fall back to the larger max.
    auto range_second_min = range_1_max > range_2_max ? range_2_max : range_1_max;
    auto range_second_max = range_1_max > range_2_max ? range_1_max : range_2_max;
    range_second_min = range_second_min == -1 ? range_second_max : range_second_min;
    range_1[dim_num] = std::pair<int64_t, int64_t>(range_fisrt, range_second_min);
    range_2[dim_num] = std::pair<int64_t, int64_t>(range_fisrt, range_second_min);
  }
}
  836. bool TwoInOneOutDynamicInferNoBroadcast(Operator& op,
  837. const string& input1_name,
  838. const string& input2_name,
  839. const std::vector<string>& output_name_list) {
  840. // get input1 desc
  841. auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  842. CHECK(op_info == nullptr || op_info->MutableInputDesc(input1_name) == nullptr ||
  843. op_info->MutableInputDesc(input2_name) == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."),
  844. return false);
  845. auto input1_desc = op_info->MutableInputDesc(input1_name);
  846. vector<int64_t> input1_shape = input1_desc->MutableShape().GetDims();
  847. DataType input_dtype = input1_desc->GetDataType();
  848. // get input2 desc
  849. auto input2_desc = op_info->MutableInputDesc(input2_name);
  850. vector<int64_t> input2_shape = input2_desc->MutableShape().GetDims();
  851. if (IsUnknown(input1_shape) || IsUnknown(input2_shape)) {
  852. std::vector<std::pair<int64_t, int64_t>> input1_range;
  853. input1_desc->GetShapeRange(input1_range);
  854. std::vector<std::pair<int64_t, int64_t>> input2_range;
  855. input2_desc->GetShapeRange(input2_range);
  856. vector<int64_t> dim_size = {};
  857. FixShapeRangeWithDims(dim_size, input1_shape, input2_shape, input1_range, input2_range);
  858. // update output desc
  859. auto output_desc = op_info->MutableOutputDesc(0);
  860. for (const string& output_name : output_name_list) {
  861. output_desc = op_info->MutableOutputDesc(output_name);
  862. output_desc->SetShape(GeShape(input1_shape));
  863. output_desc->SetOriginShape(GeShape(input1_shape));
  864. output_desc->SetShapeRange(input1_range);
  865. output_desc->SetDataType(input_dtype);
  866. }
  867. } else {
  868. auto output_desc = op_info->MutableOutputDesc(0);
  869. for (const string& output_name : output_name_list) {
  870. output_desc = op_info->MutableOutputDesc(output_name);
  871. output_desc->SetShape(GeShape(input1_shape));
  872. output_desc->SetDataType(input_dtype);
  873. }
  874. }
  875. return true;
  876. }
  877. bool SetScalarOutputDesc(const string& input, const string& output, OpDescPtr op_desc, GeShape& output_shape) {
  878. if (output_shape.IsScalar()) {
  879. auto td = op_desc->MutableOutputDesc(output);
  880. td->SetShape(output_shape);
  881. td->SetOriginShape(output_shape);
  882. td->SetDataType(op_desc->MutableInputDesc(input)->GetDataType());
  883. td->SetOriginDataType(op_desc->MutableInputDesc(input)->GetDataType());
  884. return true;
  885. } else {
  886. return false;
  887. }
  888. }
namespace array_ops {
// Returns true when a * b does NOT overflow int64 (i.e. the multiply is
// safe); false on overflow/underflow. Note the positive-means-safe naming:
// callers multiply only when this returns true.
bool CheckInt64MulOverflow(int64_t a, int64_t b) {
  if (a > 0) {
    if (b > 0) {
      // both positive: a * b > INT64_MAX  <=>  a > INT64_MAX / b
      if (a > (INT64_MAX / b)) {
        return false;
      }
    } else {
      // a > 0, b <= 0: product can underflow below INT64_MIN
      if (b < (INT64_MIN / a)) {
        return false;
      }
    }
  } else {
    if (b > 0) {
      // a <= 0, b > 0: product can underflow below INT64_MIN
      if (a < (INT64_MIN / b)) {
        return false;
      }
    } else {
      // both non-positive: product is non-negative; guard a == 0 before dividing
      if ((a != 0) && (b < (INT64_MAX / a))) {
        return false;
      }
    }
  }
  return true;
}
// Multiplies `range_max` by the upper bound of every entry in `x_range`.
// Any unbounded dim (second < 0) forces range_max to -1 (unbounded);
// int64 overflow saturates range_max to INT64_MAX with a warning.
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
                       int64_t& range_max) {
  for (const auto& ele : x_range) {
    if (ele.second < 0) {
      range_max = -1;
      return;
    }
    if (array_ops::CheckInt64MulOverflow(range_max, ele.second)) {
      range_max *= ele.second;
    } else {
      range_max = INT64_MAX;
      GE_OP_LOGW(op.GetName().c_str(), "Range Infer out of int64 max!Do set int64max!");
      return;
    }
  }
}
// Builds `y_range` for a reshape output: static output dims get an exact
// (d, d) range; dynamic dims (-1) get an upper bound derived from the total
// element budget of the input ranges.
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
                       std::vector<std::pair<int64_t, int64_t>>& y_range, GeShape& output_shape) {
  // max_input_dims = product of all input upper bounds; -1 if any is unbounded,
  // saturating to INT64_MAX on overflow.
  int64_t max_input_dims = 1;
  for (const auto& pair : x_range) {
    if (pair.second < 0) {
      max_input_dims = -1;
      break;
    }
    if (array_ops::CheckInt64MulOverflow(max_input_dims, pair.second)) {
      max_input_dims *= pair.second;
    } else {
      max_input_dims = INT64_MAX;
      GE_OP_LOGW(op.GetName().c_str(), "Range Infer out of int64 max!Do set int64max!");
      break;
    }
  }
  if (max_input_dims < 0) {
    // Unbounded input: dynamic output dims get (1, -1); static dims stay exact.
    for (const auto dim : output_shape.GetDims()) {
      if (dim < 0) {
        y_range.emplace_back(std::pair<int64_t, int64_t>(1, -1));
      } else {
        y_range.emplace_back(std::pair<int64_t, int64_t>(dim, dim));
      }
    }
  } else {
    // Bounded input: `left` tracks the remaining element budget, clamped to
    // INT32_MAX, and is divided down by each static output dim consumed.
    int64_t left = max_input_dims;
    left = (left > INT32_MAX) ? INT32_MAX : left;
    for (const auto dim : output_shape.GetDims()) {
      if (dim < 0) {
        y_range.emplace_back(std::pair<int64_t, int64_t>(1, left));
      } else {
        y_range.emplace_back(std::pair<int64_t, int64_t>(dim, dim));
        if (dim != 0) {
          // NOTE(review): (left + 0.5) / dim looks like an attempted rounding;
          // presumably ceil(left / dim) was intended — confirm before changing.
          left = static_cast<int64_t>((static_cast<double>(left) + 0.5) / dim);
        }
      }
    }
  }
}
}  // namespace array_ops
  970. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图可参见项目文档中的架构说明。