You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

array_ops.cc 69 kB

modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: 
tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: 
tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: 
tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: 
tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: 
tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/test.cc new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc modified: CMakeLists.txt modified: build.sh modified: ge/ge_runtime/runtime_model.cc modified: ge/ge_runtime/task/aicpu_task.cc modified: ge/ge_runtime/task/hccl_task.cc modified: ge/ge_runtime/task/label_goto_task.cc modified: ge/ge_runtime/task/label_switch_task.cc new file: tests/st/CMakeLists.txt new file: tests/st/cmake/graphengine.cmake new file: tests/st/framework/CMakeLists.txt new file: tests/st/framework/framework.cc new file: tests/st/framework/framework.h new file: tests/st/framework/stub_engine/CMakeLists.txt new file: tests/st/framework/stub_engine/common/constant/constant.h new file: tests/st/framework/stub_engine/engine/stub_engine.cc new file: tests/st/framework/stub_engine/engine/stub_engine.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file: tests/st/framework/stub_engine/proto/task.proto new file: tests/st/framework/stub_op_proto/array_ops.cc new file: tests/st/framework/stub_op_proto/array_ops.h new file: tests/st/framework/stub_op_proto/control_flow_ops.cc new file: tests/st/framework/stub_op_proto/control_flow_ops.h new file: 
tests/st/framework/stub_op_proto/elewise_calculation_ops.cc new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file: tests/st/framework/stub_op_proto/util/axis_util.cc new file: tests/st/framework/stub_op_proto/util/axis_util.h new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h new file: tests/st/framework/stub_op_proto/util/error_code.h new file: tests/st/framework/stub_op_proto/util/error_util.cc new file: tests/st/framework/stub_op_proto/util/error_util.h new file: tests/st/framework/stub_op_proto/util/op_common_util.h new file: tests/st/framework/stub_op_proto/util/op_log.h new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file: tests/st/framework/stub_op_proto/util/util.cc new file: tests/st/framework/stub_op_proto/util/util.h new file: tests/st/framework/utils/assertion/graph_assertion.cc new file: tests/st/framework/utils/assertion/graph_assertion.h new file: tests/st/framework/utils/builder/graph_builder_utils.cc new file: tests/st/framework/utils/builder/graph_builder_utils.h new file: tests/st/framework/utils/builder/tensor_builder_utils.cc new file: tests/st/framework/utils/builder/tensor_builder_utils.h new file: tests/st/testcase/CMakeLists.txt new file: tests/st/testcase/test_framework_dummy.cc
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file array_ops.cpp
  18. * \brief
  19. */
  20. #include "array_ops.h"
  21. #include <climits>
  22. #include <unordered_set>
  23. #include <utility>
  24. #include "./util/op_log.h"
  25. #include "./util/common_shape_fns.h"
  26. #include "./util/array_ops_shape_fns.h"
  27. #include "graph/utils/tensor_adapter.h"
  28. #include "graph/utils/node_utils.h"
  29. #include "./util/error_util.h"
  30. #include "util/util.h"
namespace ge {
// Labels used by the error-reporting helpers below when naming the offending
// field in a diagnostic message.
const char* const kShape = "shape";
const char* const kShapeDtype = "shape dtype";
const char* const kAttrShape = "attr shape";
const char* const kAttrDtype = "attr dtype";
const char* const kAttrAxis = "attr axis";
const char* const kAttrNumAxes = "attr num_axes";
// Attribute key under which an upstream op records its input shape range.
const char* const kPreOpInputShapeRange = "_pre_op_in_range";
// NOTE(review): presumably the maximum supported tensor rank — usage is not
// visible in this chunk; confirm against the callers before relying on it.
const int64_t kMaxDimNum = 8;
  40. IMPLEMT_INFERFUNC(Unique, UniqueInfer) {
  41. OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op);
  42. GeTensorDescPtr x_input = op_desc->MutableInputDesc(0);
  43. GeShape x_shape;
  44. if (WithRank(x_input, 1, x_shape) != GRAPH_SUCCESS) {
  45. ShapeErrReport(0, op.GetName(), DebugString(x_input->GetShape().GetDims()), "1D");
  46. OP_LOGE(op.GetName().c_str(), "input x must be 1-D");
  47. return GRAPH_FAILED;
  48. }
  49. DataType idx_type;
  50. if (op.GetAttr("out_idx", idx_type) != GRAPH_SUCCESS) {
  51. OP_LOGE(op.GetName().c_str(), "Op get attr out_idx failed");
  52. return GRAPH_FAILED;
  53. }
  54. GeTensorDescPtr idx_desc = op_desc->MutableOutputDesc(1);
  55. idx_desc->SetShape(x_shape);
  56. idx_desc->SetOriginShape(x_shape);
  57. idx_desc->SetDataType(idx_type);
  58. GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0);
  59. y_desc->SetShape(GeShape({UNKNOWN_DIM}));
  60. y_desc->SetOriginShape(GeShape({UNKNOWN_DIM}));
  61. y_desc->SetDataType(x_input->GetDataType());
  62. if (x_shape.GetShapeSize() == UNKNOWN_DIM) {
  63. return GRAPH_SUCCESS;
  64. } else {
  65. std::vector<std::pair<int64_t, int64_t>> range;
  66. int64_t max_dim = x_shape.GetDim(0);
  67. range.emplace_back(std::make_pair(1, max_dim));
  68. y_desc->SetShapeRange(range);
  69. return GRAPH_SUCCESS;
  70. }
  71. }
  72. INFER_FUNC_REG(Unique, UniqueInfer);
  73. IMPLEMT_INFERFUNC(Const, ConstInfer) {
  74. auto value = op.get_attr_value();
  75. auto valDesc = value.GetTensorDesc();
  76. auto dims = valDesc.GetShape().GetDims();
  77. auto attrDtype = valDesc.GetDataType();
  78. TensorDesc outDesc = op.get_output_desc_y();
  79. outDesc.SetDataType(ge::DataType(attrDtype));
  80. outDesc.SetShape(Shape(dims));
  81. (void)op.update_output_desc_y(outDesc);
  82. return GRAPH_SUCCESS;
  83. }
  84. INFER_FUNC_REG(Const, ConstInfer);
  85. IMPLEMT_INFERFUNC(Constant, ConstantInfer) {
  86. auto value = op.get_attr_value();
  87. auto valDesc = value.GetTensorDesc();
  88. auto dims = valDesc.GetShape().GetDims();
  89. auto attrDtype = valDesc.GetDataType();
  90. TensorDesc outDesc = op.get_output_desc_y();
  91. outDesc.SetDataType(ge::DataType(attrDtype));
  92. outDesc.SetShape(Shape(dims));
  93. (void)op.update_output_desc_y(outDesc);
  94. return GRAPH_SUCCESS;
  95. }
  96. INFER_FUNC_REG(Constant, ConstantInfer);
  97. graphStatus ConstAndConstantInferFormat(ge::Operator& op) {
  98. OP_LOGI(op.GetName().c_str(), "Const infer format start");
  99. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  100. auto format = op_desc->MutableOutputDesc(0)->GetOriginFormat();
  101. ConstGeTensorPtr tensor_value;
  102. if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
  103. OP_LOGE(op.GetName().c_str(), "Get attr value failed!");
  104. return GRAPH_FAILED;
  105. }
  106. if (!tensor_value) {
  107. OP_LOGE(op.GetName().c_str(), "attr tensor is not exist!");
  108. return GRAPH_FAILED;
  109. }
  110. auto tensor_ptr = const_cast<GeTensor*>(tensor_value.get());
  111. tensor_ptr->MutableTensorDesc().SetOriginFormat(format);
  112. tensor_ptr->MutableTensorDesc().SetFormat(format);
  113. return GRAPH_SUCCESS;
  114. }
// Format inference for Const delegates to the shared Const/Constant helper.
IMPLEMT_INFERFORMAT_FUNC(Const, ConstInferFormat) {
return ConstAndConstantInferFormat(op);
}
INFER_FORMAT_FUNC_REG(Const, ConstInferFormat);
  119. IMPLEMT_INFERFUNC(Snapshot, SnapshotInferFunc) {
  120. OP_LOGI(op.GetName().c_str(), "Snapshot infershape start");
  121. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  122. auto input_desc_x = op_desc->MutableInputDesc("x");
  123. auto output_desc_y = op_desc->MutableOutputDesc("y");
  124. auto x_dims = input_desc_x->MutableShape().GetDims();
  125. auto x_type = input_desc_x->GetDataType();
  126. std::vector<std::pair<int64_t, int64_t>> x_range;
  127. input_desc_x->GetShapeRange(x_range);
  128. output_desc_y->SetShape(GeShape(x_dims));
  129. output_desc_y->SetOriginShape(GeShape(x_dims));
  130. output_desc_y->SetShapeRange(x_range);
  131. output_desc_y->SetDataType(x_type);
  132. OP_LOGI(op.GetName().c_str(), "Snapshot infershape end");
  133. return GRAPH_SUCCESS;
  134. }
  135. INFER_FUNC_REG(Snapshot, SnapshotInferFunc);
  136. IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) {
  137. TensorDesc tensorDesc = op.GetInputDesc("x");
  138. (void)op.UpdateOutputDesc("y", tensorDesc);
  139. return GRAPH_SUCCESS;
  140. }
  141. INFER_FUNC_REG(GuaranteeConst, GuaranteeConstInfer);
  142. IMPLEMT_INFERFUNC(BroadcastArgs, BroadcastArgsInferFunc) {
  143. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  144. auto x1_desc = op_desc->MutableInputDesc("x1");
  145. auto x2_desc = op_desc->MutableInputDesc("x2");
  146. auto y_desc = op_desc->MutableOutputDesc("y");
  147. auto x1_dims = x1_desc->GetShape().GetDims();
  148. auto x2_dims = x2_desc->GetShape().GetDims();
  149. auto data_type = x1_desc->GetDataType();
  150. std::vector<std::pair<int64_t, int64_t>> x1_range;
  151. std::vector<std::pair<int64_t, int64_t>> x2_range;
  152. std::vector<std::pair<int64_t, int64_t>> out_range;
  153. x1_desc->GetShapeRange(x1_range);
  154. x2_desc->GetShapeRange(x2_range);
  155. bool data_type_check = ((x1_desc->GetDataType() != DT_INT32 && x1_desc->GetDataType() != DT_INT64) ||
  156. (x2_desc->GetDataType() != DT_INT32 && x2_desc->GetDataType() != DT_INT64));
  157. if (data_type_check) {
  158. string reason = "x1[" + std::to_string(x1_desc->GetDataType()) + "] + and + x2[" +
  159. std::to_string(x1_desc->GetDataType()) + "] must DT_INT32 or DT_INT64";
  160. GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason);
  161. GE_OP_LOGE(op.GetName().c_str(), "Data type check fail. x1[%u] and x2[%u] must DT_INT32 or DT_INT64",
  162. x1_desc->GetDataType(), x2_desc->GetDataType());
  163. return GRAPH_PARAM_INVALID;
  164. }
  165. if (x1_dims.size() > 1 || x2_dims.size() > 1) {
  166. string reason = "x1[" + std::to_string(x1_dims.size()) + "] + and + x2[" + std::to_string(x2_dims.size()) +
  167. "] must be less than or equal to 1";
  168. GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dims", reason);
  169. GE_OP_LOGE(op.GetName().c_str(), "Size check fail. x1[%u] and x2[%u] must be less than or equal to 1",
  170. x1_dims.size(), x2_dims.size());
  171. return GRAPH_PARAM_INVALID;
  172. }
  173. if (x1_dims == UNKNOWN_RANK || x2_dims == UNKNOWN_RANK) {
  174. GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown rank!");
  175. y_desc->SetShape(GeShape(UNKNOWN_SHAPE));
  176. y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  177. y_desc->SetDataType(data_type);
  178. return GRAPH_SUCCESS;
  179. }
  180. if (x1_dims == UNKNOWN_SHAPE && x2_dims == UNKNOWN_SHAPE) {
  181. GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown shape!");
  182. y_desc->SetShape(GeShape(UNKNOWN_SHAPE));
  183. y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  184. y_desc->SetDataType(data_type);
  185. y_desc->SetShapeRange(x1_range);
  186. return GRAPH_SUCCESS;
  187. } else if (x1_dims == UNKNOWN_SHAPE) {
  188. GE_OP_LOGD(op.GetName().c_str(), "x1 is unknown shape!");
  189. int64_t range_max = x2_dims.size();
  190. std::pair<int64_t, int64_t> pair({1, range_max});
  191. out_range.emplace_back(pair);
  192. y_desc->SetShape(GeShape(UNKNOWN_SHAPE));
  193. y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  194. y_desc->SetDataType(data_type);
  195. y_desc->SetShapeRange(out_range);
  196. return GRAPH_SUCCESS;
  197. } else if (x2_dims == UNKNOWN_SHAPE) {
  198. GE_OP_LOGD(op.GetName().c_str(), "x2 is unknown shape!");
  199. int64_t range_max = x2_dims.size();
  200. std::pair<int64_t, int64_t> pair({1, range_max});
  201. out_range.emplace_back(pair);
  202. y_desc->SetShape(GeShape(UNKNOWN_SHAPE));
  203. y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  204. y_desc->SetDataType(data_type);
  205. y_desc->SetShapeRange(out_range);
  206. return GRAPH_SUCCESS;
  207. }
  208. if (x1_dims.empty()) {
  209. y_desc->SetShape(GeShape(x2_dims));
  210. } else if (x2_dims.empty()) {
  211. y_desc->SetShape(GeShape(x1_dims));
  212. } else {
  213. auto dims = x1_dims[0] > x2_dims[0] ? x1_dims : x2_dims;
  214. y_desc->SetShape(GeShape(dims));
  215. }
  216. int64_t range_max = x1_dims.size() > x2_dims.size() ? x1_dims.size() : x2_dims.size();
  217. std::pair<int64_t, int64_t> pair({1, range_max});
  218. out_range.emplace_back(pair);
  219. y_desc->SetShapeRange(out_range);
  220. y_desc->SetDataType(x1_desc->GetDataType());
  221. return GRAPH_SUCCESS;
  222. }
  223. INFER_FUNC_REG(BroadcastArgs, BroadcastArgsInferFunc);
  224. IMPLEMT_INFERFUNC(BroadcastGradientArgs, BroadcastGradientArgsInfer) {
  225. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  226. auto input_desc_x1 = op_desc->MutableInputDesc("x1");
  227. auto input_desc_x2 = op_desc->MutableInputDesc("x2");
  228. auto output_desc_y1 = op_desc->MutableOutputDesc("y1");
  229. auto output_desc_y2 = op_desc->MutableOutputDesc("y2");
  230. auto dims_x1 = input_desc_x1->MutableShape().GetDims();
  231. auto dims_x2 = input_desc_x2->MutableShape().GetDims();
  232. auto x1_type = input_desc_x1->GetDataType();
  233. auto x2_type = input_desc_x2->GetDataType();
  234. std::vector<std::pair<int64_t, int64_t>> x1_range;
  235. std::vector<std::pair<int64_t, int64_t>> x2_range;
  236. std::vector<std::pair<int64_t, int64_t>> out_range;
  237. input_desc_x1->GetShapeRange(x1_range);
  238. input_desc_x2->GetShapeRange(x2_range);
  239. if (dims_x1 == UNKNOWN_RANK || dims_x2 == UNKNOWN_RANK) {
  240. GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown rank!");
  241. output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE));
  242. output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  243. output_desc_y1->SetDataType(x1_type);
  244. output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE));
  245. output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  246. output_desc_y2->SetDataType(x2_type);
  247. return GRAPH_SUCCESS;
  248. }
  249. // Input Dim Num must be equal or smaller than 1
  250. if (dims_x1 == UNKNOWN_SHAPE && dims_x2 == UNKNOWN_SHAPE) {
  251. GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown shape!");
  252. output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE));
  253. output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  254. output_desc_y1->SetDataType(x1_type);
  255. output_desc_y1->SetShapeRange(x1_range);
  256. output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE));
  257. output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  258. output_desc_y2->SetDataType(x2_type);
  259. output_desc_y2->SetShapeRange(x2_range);
  260. return GRAPH_SUCCESS;
  261. } else if (dims_x1 == UNKNOWN_SHAPE) {
  262. GE_OP_LOGD(op.GetName().c_str(), "x1 is unknown shape!");
  263. int64_t range_max = dims_x2.size();
  264. std::pair<int64_t, int64_t> pair({1, range_max});
  265. out_range.emplace_back(pair);
  266. output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE));
  267. output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  268. output_desc_y1->SetDataType(x1_type);
  269. output_desc_y1->SetShapeRange(out_range);
  270. output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE));
  271. output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  272. output_desc_y2->SetDataType(x2_type);
  273. output_desc_y2->SetShapeRange(out_range);
  274. return GRAPH_SUCCESS;
  275. } else if (dims_x2 == UNKNOWN_SHAPE) {
  276. GE_OP_LOGD(op.GetName().c_str(), "x2 is unknown shape!");
  277. int64_t range_max = dims_x1.size();
  278. std::pair<int64_t, int64_t> pair({1, range_max});
  279. out_range.emplace_back(pair);
  280. output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE));
  281. output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  282. output_desc_y1->SetDataType(x1_type);
  283. output_desc_y1->SetShapeRange(out_range);
  284. output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE));
  285. output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  286. output_desc_y2->SetDataType(x2_type);
  287. output_desc_y2->SetShapeRange(out_range);
  288. return GRAPH_SUCCESS;
  289. }
  290. GE_OP_LOGD(op.GetName().c_str(), "all two inputs are known shape!");
  291. int64_t range_max = dims_x1.size() == 0 ? 1 : dims_x1.size();
  292. std::pair<int64_t, int64_t> pair({1, range_max});
  293. out_range.emplace_back(pair);
  294. output_desc_y1->SetDataType(x1_type);
  295. output_desc_y2->SetDataType(x2_type);
  296. output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE));
  297. output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  298. output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE));
  299. output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  300. output_desc_y1->SetShapeRange(out_range);
  301. output_desc_y2->SetShapeRange(out_range);
  302. return GRAPH_SUCCESS;
  303. }
  304. INFER_FUNC_REG(BroadcastGradientArgs, BroadcastGradientArgsInfer);
  305. IMPLEMT_INFERFUNC(PreventGradient, PreventGradientInferFunc) {
  306. OP_LOGI(op.GetName().c_str(), "PreventGradient infershape start");
  307. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  308. auto input_desc_x = op_desc->MutableInputDesc("x");
  309. auto output_desc_y = op_desc->MutableOutputDesc("y");
  310. auto x_dims = input_desc_x->MutableShape().GetDims();
  311. auto x_type = input_desc_x->GetDataType();
  312. std::vector<std::pair<int64_t, int64_t>> x_range;
  313. input_desc_x->GetShapeRange(x_range);
  314. output_desc_y->SetShape(GeShape(x_dims));
  315. output_desc_y->SetOriginShape(GeShape(x_dims));
  316. output_desc_y->SetShapeRange(x_range);
  317. output_desc_y->SetDataType(x_type);
  318. OP_LOGI(op.GetName().c_str(), "PreventGradient infershape end");
  319. return GRAPH_SUCCESS;
  320. }
  321. INFER_FUNC_REG(PreventGradient, PreventGradientInferFunc);
  322. IMPLEMT_INFERFUNC(StopGradient, StopGradientInferFunc) {
  323. OP_LOGI(op.GetName().c_str(), "StopGradient infershape start");
  324. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  325. auto input_desc_x = op_desc->MutableInputDesc("x");
  326. auto output_desc_y = op_desc->MutableOutputDesc("y");
  327. auto x_dims = input_desc_x->MutableShape().GetDims();
  328. auto x_type = input_desc_x->GetDataType();
  329. std::vector<std::pair<int64_t, int64_t>> x_range;
  330. input_desc_x->GetShapeRange(x_range);
  331. output_desc_y->SetShape(GeShape(x_dims));
  332. output_desc_y->SetOriginShape(GeShape(x_dims));
  333. output_desc_y->SetShapeRange(x_range);
  334. output_desc_y->SetShapeRange(x_range);
  335. output_desc_y->SetDataType(x_type);
  336. OP_LOGI(op.GetName().c_str(), "StopGradient infershape end");
  337. return GRAPH_SUCCESS;
  338. }
  339. INFER_FUNC_REG(StopGradient, StopGradientInferFunc);
  340. IMPLEMT_INFERFUNC(ExpandDims, ExpandDimsInfer) {
  341. std::vector<string> dep_inputs = {"axis"};
  342. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  343. auto node = NodeUtils::GetNodeFromOperator(op);
  344. if (node == nullptr) {
  345. GE_OP_LOGE(op.GetName().c_str(), "get null node ptr");
  346. return GRAPH_FAILED;
  347. }
  348. auto x_desc = op_desc->MutableInputDesc("x");
  349. auto axis_desc = op_desc->MutableInputDesc("axis");
  350. auto y_desc = op_desc->MutableOutputDesc("y");
  351. op_desc->SetOpInferDepends(dep_inputs);
  352. auto axis_type = axis_desc->GetDataType();
  353. auto x_type = x_desc->GetDataType();
  354. if (axis_type != DT_INT32 && axis_type != DT_INT64) {
  355. string reason = "axis dtype[" + std::to_string(axis_type) + "] must int32 or int64";
  356. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrDtype, reason);
  357. GE_OP_LOGE(op.GetName().c_str(), "axis dtype[%d] must int32 or int64", axis_type);
  358. return GRAPH_PARAM_INVALID;
  359. }
  360. bool is_x_unknonwn_rank = x_desc->MutableShape().GetDims() == UNKNOWN_RANK ? true : false;
  361. if (is_x_unknonwn_rank) {
  362. GE_OP_LOGD("input x shape is unknown rank!");
  363. y_desc->SetUnknownDimNumShape();
  364. y_desc->SetDataType(x_type);
  365. y_desc->SetOriginDataType(x_type);
  366. return GRAPH_SUCCESS;
  367. }
  368. int64_t axis_nums = axis_desc->MutableShape().GetShapeSize();
  369. if (axis_nums != 1) {
  370. // Shape::GetDims().size() == 0, means it's a scalar, its shape is [].
  371. if (!(axis_nums == 0 && axis_desc->MutableShape().GetDims().size() == 0)) {
  372. string reason = "axis input must be a tensor with a single value, but [" + std::to_string(axis_nums) + "] nums";
  373. GeInfershapeErrReport(op.GetName(), op.GetOpType(), "axis", reason);
  374. GE_OP_LOGE(op.GetName().c_str(), "'axis' input must be a tensor with a single value, but %d nums", axis_nums);
  375. return GRAPH_PARAM_INVALID;
  376. }
  377. }
  378. GeTensorPtr tensor_axis = nullptr;
  379. graphStatus status = NodeUtils::GetInputConstData(node, "axis", tensor_axis);
  380. if (status != GRAPH_SUCCESS) {
  381. GE_OP_LOGI(op.GetName().c_str(), "Op get input const data of axis failed");
  382. auto x_shape_size = x_desc->MutableShape().GetDims().size();
  383. std::vector<int64_t> out_dims(x_shape_size + 1, UNKNOWN_DIM);
  384. y_desc->SetShape(GeShape(out_dims));
  385. y_desc->SetOriginShape(GeShape(out_dims));
  386. y_desc->SetDataType(x_type);
  387. y_desc->SetOriginDataType(x_type);
  388. // infer shape range
  389. std::vector<std::pair<int64_t, int64_t>> x_range;
  390. (void)x_desc->GetShapeRange(x_range);
  391. if (x_range.empty()) {
  392. GE_OP_LOGD(op.GetName().c_str(), "last op does not set shape range!");
  393. return GRAPH_SUCCESS;
  394. }
  395. if (x_range.size() != x_shape_size) {
  396. GE_OP_LOGE(op.GetName().c_str(),
  397. "input range size num[%zu] should be same with input shape size[%zu]", x_range.size(), x_shape_size);
  398. return GRAPH_FAILED;
  399. }
  400. int64_t max_range_value = 1;
  401. for (const auto &ele : x_range) {
  402. if (ele.second > max_range_value) {
  403. max_range_value = ele.second;
  404. }
  405. }
  406. std::vector<std::pair<int64_t, int64_t>> y_range(x_shape_size + 1, std::pair<int64_t, int64_t>({1, max_range_value}));
  407. y_desc->SetShapeRange(y_range);
  408. return GRAPH_SUCCESS;
  409. }
  410. auto pbuff = tensor_axis->GetData().GetData();
  411. if (pbuff == nullptr) {
  412. GE_OP_LOGE(op.GetName().c_str(), "no const data when get data from tensor!");
  413. return GRAPH_FAILED;
  414. }
  415. int64_t axis;
  416. if (axis_type == DT_INT32) {
  417. axis = *const_cast<int32_t*>(reinterpret_cast<const int32_t*>(pbuff));
  418. } else if (axis_type == DT_INT64) {
  419. axis = *const_cast<int64_t*>(reinterpret_cast<const int64_t*>(pbuff));
  420. }
  421. std::vector<int64_t> vec_dim;
  422. int32_t dim_num = x_desc->MutableShape().GetDimNum();
  423. if (axis < -1 - dim_num || axis > dim_num) {
  424. string reason = "axis[" + std::to_string(axis) + "] is not in [" + std::to_string(-1 - dim_num) + " , " +
  425. std::to_string(dim_num) + "]";
  426. GeInfershapeErrReport(op.GetName(), op.GetOpType(), "axis", reason);
  427. GE_OP_LOGE(op.GetName().c_str(), "axis[%d] is not in [%d, %d]", axis, -1 - dim_num, dim_num);
  428. return GRAPH_PARAM_INVALID;
  429. }
  430. if (axis < 0) {
  431. axis += dim_num + 1;
  432. }
  433. for (int i = 0; i < dim_num; i++) {
  434. vec_dim.push_back(x_desc->MutableShape().GetDim(i));
  435. }
  436. vec_dim.emplace(vec_dim.begin() + axis, 1);
  437. y_desc->SetShape(GeShape(vec_dim));
  438. y_desc->SetOriginShape(GeShape(vec_dim));
  439. y_desc->SetDataType(x_type);
  440. y_desc->SetOriginDataType(x_type);
  441. // infer shape range
  442. auto x_shape_size = x_desc->MutableShape().GetDims().size();
  443. std::vector<std::pair<int64_t, int64_t>> x_range;
  444. (void)x_desc->GetShapeRange(x_range);
  445. if (x_range.empty()) {
  446. GE_OP_LOGD(op.GetName().c_str(), "last op does not set shape range, so break!");
  447. return GRAPH_SUCCESS;
  448. }
  449. if (x_range.size() != x_shape_size) {
  450. GE_OP_LOGE(op.GetName().c_str(),
  451. "input range size num[%zu] should be same with input shape size[%zu]", x_range.size(), x_shape_size);
  452. return GRAPH_FAILED;
  453. }
  454. x_range.emplace(x_range.begin() + axis, std::pair<int64_t, int64_t>{1, 1});
  455. y_desc->SetShapeRange(x_range);
  456. return GRAPH_SUCCESS;
  457. }
  458. INFER_FUNC_REG(ExpandDims, ExpandDimsInfer);
  459. template <typename T>
  460. static graphStatus ValidateShape(const GeTensorPtr& tenosr, int64_t& product, int& unknow_index, GeShape& output,
  461. Operator& op) {
  462. int64_t dim_num = tenosr->MutableTensorDesc().MutableShape().GetDim(0);
  463. T* shape_data = const_cast<T*>(reinterpret_cast<const T*>(tenosr->GetData().GetData()));
  464. std::vector<int64_t> out_dims = output.GetDims();
  465. if (shape_data == nullptr) {
  466. GE_OP_LOGE(op.GetName().c_str(), "truth shape data is invalid");
  467. return GRAPH_PARAM_INVALID;
  468. }
  469. for (int64_t i = 0; i < dim_num; i++) {
  470. if (shape_data[i] == -1) {
  471. if (unknow_index != -1) {
  472. string reason = "only one dim may be -1, not both dim[ " + std::to_string(unknow_index) + "] and dim[" +
  473. std::to_string(i) + "]";
  474. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
  475. GE_OP_LOGE(op.GetName().c_str(), "Only one dim may be -1, not both dim[%lld] and dim[%lld]", unknow_index, i);
  476. return GRAPH_PARAM_INVALID;
  477. }
  478. unknow_index = i;
  479. out_dims.push_back(1);
  480. } else if (shape_data[i] < 0) {
  481. string reason = "Size[" + std::to_string(i) + "] must be non-negative";
  482. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
  483. GE_OP_LOGE(op.GetName().c_str(), "Size[%lld] must be non-negative", i);
  484. return GRAPH_PARAM_INVALID;
  485. } else {
  486. if (shape_data[i] != 0 && product > (INT64_MAX / shape_data[i])) {
  487. string reason = "Mul overflow of int64, product[" + std::to_string(product) + "] shape_data[" +
  488. std::to_string((int64_t)shape_data[i]) + "]";
  489. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
  490. GE_OP_LOGE(op.GetName().c_str(), "Mul overflow of int64, product[%lld] shape_data[%lld]", product,
  491. (int64_t)shape_data[i]);
  492. return GRAPH_PARAM_INVALID;
  493. }
  494. out_dims.push_back(shape_data[i]);
  495. product *= shape_data[i];
  496. }
  497. }
  498. output = GeShape(out_dims);
  499. return GRAPH_SUCCESS;
  500. }
// Caffe-style Reshape shape inference. Only the input axes in [start_axis, end_axis)
// (derived from `axis`/`num_axes`) are replaced by `dims`; axes outside that window
// are retained. Within `dims`, 0 means "copy the corresponding input dim" and -1
// means "infer this dim from the remaining element count" (at most one -1).
// Returns GRAPH_SUCCESS and updates output "y" on success, an error status otherwise.
static graphStatus CaffeReshapeInferShape(const vector<int64_t>& dims, const int64_t& axis, const int64_t& num_axes,
                                          Operator& op) {
  GE_OP_LOGI(op.GetName().c_str(), "Reshape infer shape start");
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto x_desc = op_desc->MutableInputDesc("x");
  auto shape_desc = op_desc->MutableInputDesc("shape");  // NOTE(review): fetched but unused below
  auto y_desc = op_desc->MutableOutputDesc("y");
  auto x_dims = x_desc->GetShape().GetDims();
  auto data_type = x_desc->GetDataType();
  // Unknown-rank input (or unknown-rank target dims): output is unknown rank too.
  if (x_dims == UNKNOWN_RANK || dims == UNKNOWN_RANK) {
    GE_OP_LOGD("Input data is unknown_rank");
    y_desc->SetShape(GeShape(UNKNOWN_RANK));
    y_desc->SetOriginShape(GeShape(UNKNOWN_RANK));
    y_desc->SetDataType(data_type);
    return GRAPH_SUCCESS;
  }
  // Unknown-shape input: propagate unknown shape; exact dims resolved at runtime.
  if (x_dims == UNKNOWN_SHAPE) {
    GE_OP_LOGD("Input data is unknown_shape.");
    y_desc->SetShape(GeShape(UNKNOWN_SHAPE));
    y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE));
    y_desc->SetDataType(data_type);
    return GRAPH_SUCCESS;
  }
  int64_t inferred_axis = -1;   // position of the single allowed -1 within `dims`
  int64_t constant_count = 1;   // product of explicitly given (non-0, non--1) dims
  vector<int64_t> copy_axes;    // positions within `dims` whose value is 0 (copy from input)
  // parsing dims
  for (size_t i = 0; i < dims.size(); ++i) {
    const int64_t shape_dim_i = dims[i];
    if (shape_dim_i == 0) {
      copy_axes.push_back(i);
    } else if (shape_dim_i == -1) {
      if (inferred_axis != -1) {
        string reason = "only one dim may be -1, not both dim[ " + std::to_string(inferred_axis) + "] and dim[" +
                        std::to_string(i) + "]";
        GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason);
        GE_OP_LOGE(op.GetName().c_str(), "Only one dim may be -1, not both dim[%ld] and dim[%zu]", inferred_axis, i);
        return GRAPH_PARAM_INVALID;
      }
      inferred_axis = i;
    } else {
      constant_count *= shape_dim_i;
    }
  }
  // parsing start axis and end axis
  Shape bottom_shape = op.GetInputDesc("x").GetShape();
  const int64_t bottom_shape_size = bottom_shape.GetDims().size();
  // Negative axis counts from the back, Caffe-style (-1 maps to rank, i.e. append).
  int64_t start_axis = 0;
  if (axis >= 0) {
    start_axis = axis;
  } else {
    start_axis = axis + bottom_shape_size + 1;
  }
  if (start_axis < 0 || start_axis > bottom_shape_size) {
    int64_t range = -1 - bottom_shape_size;
    // if axis >=0 , axis range [0, bottom_shape_size], else axis < 0, axis range [-1 - bottom_shape_size, -1]
    // axis range [-1 - bottom_shape_size, bottom_shape_size]
    string reason = "axis's range is not in [" + std::to_string(range) + ", " + std::to_string(bottom_shape_size) + "]";
    GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason);
    GE_OP_LOGE(op.GetName().c_str(), "reshape param axis is invalid, axis's range is not in [%ld, %ld]", range,
               bottom_shape_size);
    return GRAPH_PARAM_INVALID;
  }
  // num_axes == -1 means "replace through the last axis"; otherwise a fixed window.
  int64_t end_axis = 0;
  if (num_axes < -1) {
    GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrNumAxes, "it must be greater than or equal to -1");
    GE_OP_LOGE(op.GetName().c_str(), "reshape param num_axes is invalid, it must be greater than or equal to -1");
    return GRAPH_PARAM_INVALID;
  } else if (num_axes == -1) {
    end_axis = bottom_shape_size;
  } else {
    end_axis = start_axis + num_axes;
  }
  if (end_axis > bottom_shape_size) {
    GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrNumAxes,
                          "num_axes must be less than or equal to " + std::to_string((bottom_shape_size - start_axis)));
    GE_OP_LOGE(op.GetName().c_str(), "reshape param num_axes is invalid, it must be less than or equal to %ld",
               bottom_shape_size - start_axis);
    return GRAPH_PARAM_INVALID;
  }
  // construct top shape: retained prefix + new dims + retained suffix
  vector<int64_t> bottom_dims = bottom_shape.GetDims();
  const int64_t num_axes_replaced = end_axis - start_axis;
  const int64_t num_axes_retained = bottom_shape_size - num_axes_replaced;
  const int64_t num_new_axes = dims.size();
  vector<int64_t> top_shape(num_axes_retained + num_new_axes);
  size_t top_shape_index = 0;
  for (int64_t i = 0; i < start_axis; ++i) {
    top_shape[top_shape_index] = bottom_dims[i];
    top_shape_index++;
  }
  for (int64_t i = 0; i < num_new_axes; ++i) {
    top_shape[top_shape_index] = dims[i];
    top_shape_index++;
  }
  for (int64_t i = end_axis; i < bottom_shape_size; ++i) {
    top_shape[top_shape_index] = bottom_dims[i];
    top_shape_index++;
  }
  // Sanity check: every slot of the preallocated top shape must have been filled.
  if (top_shape_index != top_shape.size()) {
    GeInfershapeErrReport(op.GetName(), op.GetOpType(), "infer shape size",
                          "top_shape_index not equal to top_shape size");
    GE_OP_LOGE(op.GetName().c_str(), "reshape infer shape faied, top_shape_index not equal to top_shape size");
    return GRAPH_FAILED;
  }
  // product of [0,start_axis) + [end_axis, bottom_shape_size)
  int64_t explicit_count = constant_count;
  int64_t bottom_count_all = 1;
  for (int i = 0; i < bottom_shape_size; ++i) {
    bottom_count_all *= bottom_dims[i];
    if (i < start_axis || i >= end_axis) {
      explicit_count *= bottom_dims[i];
    }
  }
  // parsing dim 0 and -1: each 0 copies the matching input dim into the output
  // and contributes to the explicit element count.
  for (size_t i = 0; i < copy_axes.size(); ++i) {
    const int64_t copy_axis_index = copy_axes[i];
    if ((start_axis + copy_axis_index) >= bottom_shape_size) {
      GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape,
                            "there was no corresponding bottom axis for dim 0");
      GE_OP_LOGE(op.GetName().c_str(), "there was no corresponding bottom axis for dim 0.");
      return GRAPH_FAILED;
    }
    top_shape[start_axis + copy_axis_index] = bottom_dims[start_axis + copy_axis_index];
    explicit_count *= bottom_dims[start_axis + copy_axis_index];
  }
  // Resolve the single -1: remaining element count must divide evenly.
  if (inferred_axis >= 0) {
    if (bottom_count_all % explicit_count != 0) {
      string reason =
          "The shape of the input cannot be divisible by the product "
          "of the specified dimensions, the product is [" +
          std::to_string(explicit_count) + "]";
      GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason);
      GE_OP_LOGE(
          op.GetName().c_str(),
          "The shape of the input cannot be divisible by the product of the specified dimensions, the product is %ld",
          explicit_count);
      return GRAPH_FAILED;
    }
    const int64_t inferred_dim = bottom_count_all / explicit_count;
    top_shape[start_axis + inferred_axis] = inferred_dim;
  }
  // Final invariant: reshape must preserve the total element count.
  int64_t top_count_all = 1;
  for (size_t i = 0; i < top_shape.size(); ++i) {
    top_count_all *= top_shape[i];
  }
  if (top_count_all != bottom_count_all) {
    string reason = "output tensor count [ " + std::to_string(top_count_all) + "] does not match input tensor count [" +
                    std::to_string(bottom_count_all) + "].";
    GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason);
    GE_OP_LOGE(op.GetName().c_str(), "output tensor count %lld does not match input tensor count %ld.", top_count_all,
               bottom_count_all);
    return GRAPH_FAILED;
  }
  // update output shape info
  TensorDesc td = op.GetOutputDesc("y");
  td.SetShape(Shape(top_shape));
  td.SetDataType(op.GetInputDesc("x").GetDataType());
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}
  662. bool IsEmptyTensor(GeTensorDescPtr tensor_desc) {
  663. bool is_empty = false;
  664. for (const auto &dim : tensor_desc->MutableShape().GetDims()) {
  665. if (dim == 0) {
  666. is_empty = true;
  667. break;
  668. }
  669. }
  670. return is_empty;
  671. }
  672. template <typename T>
  673. graphStatus GetOutShapeFromTensor(OpDescPtr op_desc, GeTensorPtr tensor, std::vector<int64_t> &v_out) {
  674. auto shape_desc = tensor->MutableTensorDesc();
  675. T* shape_data = const_cast<T*>(reinterpret_cast<const T*>(tensor->GetData().GetData()));
  676. if (shape_data == nullptr) {
  677. GE_OP_LOGE(op_desc->GetName().c_str(), "const shape data is invalid");
  678. return GRAPH_PARAM_INVALID;
  679. }
  680. for (int i = 0; i < shape_desc.MutableShape().GetDim(0); i++) {
  681. v_out.emplace_back(shape_data[i]);
  682. }
  683. return GRAPH_SUCCESS;
  684. }
  685. graphStatus EmptyTensorProcess(const Operator &op, const GeTensorDesc &x_desc, const GeTensorPtr &shape_tensor,
  686. GeTensorDesc &out_desc) {
  687. GE_OP_LOGD("Start empty-tensor preprocess!");
  688. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  689. auto shape_type = op_desc->MutableInputDesc("shape")->GetDataType();
  690. std::vector<int64_t> shape_shape;
  691. graphStatus ret = GRAPH_SUCCESS;
  692. if (shape_type == DT_INT32) {
  693. ret = GetOutShapeFromTensor<int32_t>(op_desc, shape_tensor, shape_shape);
  694. } else if (shape_type == DT_INT64) {
  695. ret = GetOutShapeFromTensor<int32_t>(op_desc, shape_tensor, shape_shape);
  696. } else {
  697. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShapeDtype,
  698. "Dim type must be DT_INT32 or DT_INT64.");
  699. GE_OP_LOGE(op.GetName().c_str(), "Dim type must be DT_INT32 or DT_INT64.");
  700. return GRAPH_PARAM_INVALID;
  701. }
  702. if (ret != GRAPH_SUCCESS) {
  703. return ret;
  704. }
  705. GE_OP_LOGD(op.GetName().c_str(), "x shape: %s shape shape: %s", x_desc.GetShape().ToString().c_str(),
  706. GeShape(shape_shape).ToString().c_str());
  707. int64_t num_of_neg_1 = 0;
  708. int64_t product = 1;
  709. for (auto &dim : shape_shape) {
  710. if (dim == -1) { // -1 stand for highest dim here
  711. num_of_neg_1++;
  712. dim = 0;
  713. }
  714. product *= dim;
  715. }
  716. // check valid
  717. if ((num_of_neg_1 == 0 && product == 0) || (num_of_neg_1 == 1)) {
  718. out_desc.SetShape(GeShape(shape_shape));
  719. out_desc.SetOriginShape(GeShape(shape_shape));
  720. out_desc.SetDataType(x_desc.GetDataType());
  721. out_desc.SetOriginDataType(x_desc.GetDataType());
  722. return GRAPH_SUCCESS;
  723. }
  724. GE_OP_LOGE(op.GetName().c_str(),
  725. "Param is invalid!.Please check!Input shape contains -1 num is %ld, product is %ld", num_of_neg_1, product);
  726. return GRAPH_FAILED;
  727. }
  728. IMPLEMT_INFERFUNC(Reshape, ReshapeInfer) {
  729. bool zero_flag = false;
  730. vector<int64_t> attr_dims;
  731. if (op.GetAttr("shape", attr_dims) == GRAPH_SUCCESS) {
  732. for (size_t i = 0; i < attr_dims.size(); ++i) {
  733. if (attr_dims[i] == 0) {
  734. zero_flag = true;
  735. break;
  736. }
  737. }
  738. }
  739. std::vector<string> dep_inputs = {"shape"};
  740. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  741. op_desc->SetOpInferDepends(dep_inputs);
  742. auto x_desc = op_desc->MutableInputDesc("x");
  743. auto y_desc = op_desc->MutableOutputDesc("y");
  744. int64_t attr_axis = 0;
  745. op.GetAttr("axis", attr_axis);
  746. int64_t attr_num_axes = -1;
  747. op.GetAttr("num_axes", attr_num_axes);
  748. if (attr_axis != 0 || attr_num_axes != -1 || zero_flag) {
  749. GE_OP_LOGI(op.GetName().c_str(), "Get reshape_param successfully, shape size is %u, axis is %ld, num_axes is %ld",
  750. attr_dims.size(), attr_axis, attr_num_axes);
  751. graphStatus caffe_reshape_ret = CaffeReshapeInferShape(attr_dims, attr_axis, attr_num_axes, op);
  752. return caffe_reshape_ret;
  753. }
  754. GE_OP_LOGI(op.GetName().c_str(), "Reshape infer shape start");
  755. GeTensorPtr tensor = nullptr;
  756. auto node = NodeUtils::GetNodeFromOperator(op);
  757. if (node == nullptr) {
  758. OP_LOGE(op.GetName().c_str(), "get null node ptr!");
  759. return GRAPH_PARAM_INVALID;
  760. }
  761. graphStatus state = NodeUtils::GetInputConstData(node, "shape", tensor);
  762. if (state != GRAPH_SUCCESS) {
  763. GE_OP_LOGW(op.GetName().c_str(), "Op get input const data of shape failed");
  764. auto input_shape = op_desc->MutableInputDesc("x")->MutableShape();
  765. auto shape_input_desc = op_desc->MutableInputDesc("shape");
  766. auto shape_shape = shape_input_desc->MutableShape();
  767. // because shape's value stand for output shape, so it should be smaller than 1 dim
  768. auto shape_rank = shape_shape.GetDims().size();
  769. if (shape_rank > 1) {
  770. string reason =
  771. "shape dim[" + std::to_string(shape_shape.GetDims().size()) + "] should be smaller or equal than 1";
  772. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
  773. GE_OP_LOGE(op.GetName().c_str(), "shape dim[%zu] should be smaller or equal than 1",
  774. shape_shape.GetDims().size());
  775. return GRAPH_PARAM_INVALID;
  776. }
  777. if (shape_shape.GetDims() != UNKNOWN_RANK && shape_shape.GetDims() != UNKNOWN_SHAPE) {
  778. auto x_type = op_desc->MutableInputDesc("x")->GetDataType();
  779. auto td = op_desc->MutableOutputDesc("y");
  780. int64_t rank = (shape_rank == 0) ? 0 : shape_shape.GetDims().at(0);
  781. td->SetShape(GeShape(std::vector<int64_t>(rank, UNKNOWN_DIM)));
  782. td->SetOriginShape(GeShape(std::vector<int64_t>(rank, UNKNOWN_DIM)));
  783. td->SetDataType(x_type);
  784. // calc shape range
  785. if (input_shape.GetDims() == UNKNOWN_RANK) {
  786. GE_OP_LOGD("input x is unknown rank!no way to set shape range!");
  787. return GRAPH_SUCCESS;
  788. }
  789. auto input_shape_size = input_shape.GetShapeSize();
  790. int64_t range_max = 1;
  791. if (input_shape_size <= 0) {
  792. // unknown dim , by input shape range calc output range
  793. std::vector<std::pair<int64_t, int64_t>> x_range;
  794. (void)op_desc->MutableInputDesc("x")->GetShapeRange(x_range);
  795. if (x_range.empty()) {
  796. return GRAPH_SUCCESS;
  797. }
  798. ge::array_ops::ReshapeRangeInfer(op, x_range, range_max);
  799. } else {
  800. // known dim, shape size as range_max
  801. range_max = input_shape_size;
  802. }
  803. range_max = (range_max > INT32_MAX) ? INT32_MAX : range_max;
  804. std::vector<std::pair<int64_t, int64_t>> y_range(rank, {1, range_max});
  805. td->SetShapeRange(y_range);
  806. return GRAPH_SUCCESS;
  807. }
  808. auto x_type = op_desc->MutableInputDesc("x")->GetDataType();
  809. auto td = op_desc->MutableOutputDesc("y");
  810. td->SetShape(GeShape({-2}));
  811. td->SetOriginShape(GeShape({-2}));
  812. td->SetDataType(x_type);
  813. return GRAPH_SUCCESS;
  814. }
  815. if (IsEmptyTensor(x_desc)) {
  816. return EmptyTensorProcess(op, *x_desc, tensor, *y_desc);
  817. }
  818. std::vector<std::pair<int64_t, int64_t>> x_range;
  819. std::vector<std::pair<int64_t, int64_t>> y_range;
  820. op_desc->MutableInputDesc("x")->GetShapeRange(x_range);
  821. int64_t product = 1;
  822. int unknow_index = -1;
  823. GeShape output_shape;
  824. DataType shape_type = op_desc->MutableInputDesc("shape")->GetDataType();
  825. int64_t shape_size = op_desc->MutableInputDesc("shape")->MutableShape().GetShapeSize();
  826. graphStatus ret = GRAPH_SUCCESS;
  827. if (shape_type == DT_INT32) {
  828. ret = ValidateShape<int32_t>(tensor, product, unknow_index, output_shape, op);
  829. } else if (shape_type == DT_INT64) {
  830. ret = ValidateShape<int64_t>(tensor, product, unknow_index, output_shape, op);
  831. } else if (shape_size > 0) {
  832. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShapeDtype, "Dim type must be DT_INT32 or DT_INT64.");
  833. GE_OP_LOGE(op.GetName().c_str(), "Dim type must be DT_INT32 or DT_INT64.");
  834. return GRAPH_PARAM_INVALID;
  835. }
  836. if (ret != GRAPH_SUCCESS) {
  837. GE_OP_LOGE(op.GetName().c_str(), "ValidateShape failed, ret: %d", ret);
  838. return ret;
  839. }
  840. auto input_shape = op_desc->MutableInputDesc("x")->MutableShape();
  841. int64_t input_size = input_shape.GetShapeSize();
  842. // If input tensor is scalar,then input_size will return 0, assign to 1, which means convert scalar to vector.
  843. if (input_size == 0 && output_shape.GetShapeSize() == 1) {
  844. input_size = 1;
  845. }
  846. if (unknow_index != -1) {
  847. if (product <= 0) {
  848. GE_OP_LOGE(op.GetName().c_str(), "Reshape Op can't infer an empty tensor");
  849. return GRAPH_PARAM_INVALID;
  850. }
  851. if (input_shape.GetShapeSize() < 0) {
  852. GE_OP_LOGI("input x and input shape is all unknown!");
  853. auto td = op_desc->MutableOutputDesc("y");
  854. output_shape.SetDim(unknow_index, -1);
  855. td->SetOriginDataType(op_desc->MutableInputDesc("x")->GetDataType());
  856. td->SetShape(output_shape);
  857. td->SetOriginShape(output_shape);
  858. td->SetDataType(op_desc->MutableInputDesc("x")->GetDataType());
  859. auto max_input_dims = 1;
  860. // If last op does not set shape range ,do not set shape range
  861. if (x_range.empty()) {
  862. GE_OP_LOGI(op.GetName().c_str(), "input x doesnot have shape range!");
  863. } else {
  864. // If last op have already set shape range, try best to infer shape range
  865. ge::array_ops::ReshapeRangeInfer(op, x_range, y_range, output_shape);
  866. }
  867. td->SetShapeRange(y_range);
  868. return GRAPH_SUCCESS;
  869. }
  870. int64_t missing = input_size / product;
  871. if (product * missing != input_size) {
  872. string reason = "The shape of the input cannot be divisible from [" + std::to_string(product) + "]";
  873. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
  874. GE_OP_LOGE(op.GetName().c_str(), "The shape of the input cannot be divisible from %lld", product);
  875. return GRAPH_PARAM_INVALID;
  876. }
  877. output_shape.SetDim(unknow_index, missing);
  878. }
  879. auto dims = input_shape.GetDims();
  880. bool is_exist_unknown_shape = false;
  881. for (auto ele : dims) {
  882. is_exist_unknown_shape = (ele == -1) ? true : false;
  883. if (!is_exist_unknown_shape) {
  884. continue;
  885. }
  886. }
  887. if (SetScalarOutputDesc(string("x"), string("y"), op_desc, output_shape)) {
  888. return GRAPH_SUCCESS;
  889. }
  890. // Shape_size is 0, means shape tensor value is [], implying convert vector/scalar to scalar
  891. bool convert_to_scalar =
  892. (shape_size == 0 && (input_size == 1 || (input_size == 0 && input_shape.GetDims().size() == 0)));
  893. // Output_shape.GetShapeSize() > 0 and input_size <= 0 for dynamic shape
  894. bool shape_check_ok =
  895. ((input_size == output_shape.GetShapeSize()) || ((output_shape.GetShapeSize() > 0) && (input_size <= 0)) ||
  896. (is_exist_unknown_shape && (output_shape.GetShapeSize() > 0)));
  897. if (!shape_check_ok && !convert_to_scalar) {
  898. string reason = "Shape size is [" + std::to_string(shape_size) + "], input tensor with [" +
  899. std::to_string(input_size) + "] values, is input dynamic shape [" +
  900. std::to_string(is_exist_unknown_shape) + "], but requested shape has [" +
  901. std::to_string(output_shape.GetShapeSize()) + "] values";
  902. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
  903. GE_OP_LOGE(op.GetName().c_str(),
  904. "Shape size is %lld, input tensor with %lld values, is input dynamic shape :%d, but \
  905. requested shape has %lld values",
  906. shape_size, input_size, is_exist_unknown_shape, output_shape.GetShapeSize());
  907. return GRAPH_PARAM_INVALID;
  908. }
  909. auto td = op_desc->MutableOutputDesc("y");
  910. td->SetShape(output_shape);
  911. td->SetOriginShape(output_shape);
  912. td->SetDataType(op_desc->MutableInputDesc("x")->GetDataType());
  913. td->SetOriginDataType(op_desc->MutableInputDesc("x")->GetDataType());
  914. return GRAPH_SUCCESS;
  915. }
  916. INFER_FUNC_REG(Reshape, ReshapeInfer);
  917. IMPLEMT_INFERFORMAT_FUNC(Reshape, ReshapeInferFormat) {
  918. GE_OP_LOGI(op.GetName().c_str(), "Reshape infer format start");
  919. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  920. auto input_descs = op_desc->GetAllInputsDescPtr();
  921. auto output_descs = op_desc->GetAllOutputsDescPtr();
  922. for (const auto& input_desc : input_descs) {
  923. if (input_desc->GetShape().GetDimNum() < 4) {
  924. input_desc->SetOriginFormat(FORMAT_ND);
  925. input_desc->SetFormat(FORMAT_ND);
  926. }
  927. }
  928. for (const auto& output_desc : output_descs) {
  929. if (output_desc->GetShape().GetDimNum() < 4) {
  930. output_desc->SetOriginFormat(FORMAT_ND);
  931. output_desc->SetFormat(FORMAT_ND);
  932. }
  933. }
  934. (void)op_desc->DefaultInferFormat();
  935. for (const auto& input_desc : input_descs) {
  936. if (input_desc->GetShape().GetDimNum() < 4) {
  937. input_desc->SetOriginFormat(FORMAT_ND);
  938. input_desc->SetFormat(FORMAT_ND);
  939. }
  940. }
  941. for (const auto& output_desc : output_descs) {
  942. if (output_desc->GetShape().GetDimNum() < 4) {
  943. output_desc->SetOriginFormat(FORMAT_ND);
  944. output_desc->SetFormat(FORMAT_ND);
  945. }
  946. }
  947. return GRAPH_SUCCESS;
  948. }
  949. INFER_FORMAT_FUNC_REG(Reshape, ReshapeInferFormat);
// Verifier for Squeeze: when the input shape is static, checks that every
// entry of the "axis" attribute names an existing dimension and that the
// named dimension has extent 1. Skipped for unknown-shape inputs.
IMPLEMT_VERIFIER(Squeeze, SqueezeVerify) {
  GE_OP_LOGD("Enter SqueezeVerify");
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto axis = op.get_attr_axis();
  auto input_desc_x = op_desc->MutableInputDesc("x");
  auto xShape = input_desc_x->MutableShape().GetDims();
  std::vector<std::pair<int64_t, int64_t>> x_range;
  input_desc_x->GetShapeRange(x_range);
  if ((xShape != UNKNOWN_RANK) && (!x_range.empty()) && (x_range.size() != xShape.size())) {
    // if it has set shape range, it should be same with input dim num
    GE_OP_LOGE("x_shape_range num [%zu] does not match x dims_num [%zu]", x_range.size(), xShape.size());
    return GRAPH_FAILED;
  }
  auto node = NodeUtils::GetNodeFromOperator(op);
  if (node == nullptr) {
    GE_OP_LOGE("node pointer is nullptr");
    return GRAPH_FAILED;
  }
  bool is_unknow = false;
  auto status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknow);
  if (status != GRAPH_SUCCESS) {
    GE_OP_LOGE("Get node unknown shape status failed!");
    return GRAPH_FAILED;
  }
  if (is_unknow) {
    // when input is unknown , no way to check param "axis" whether valid. Do check when running
    return GRAPH_SUCCESS;
  }
  if (axis.size() > 0) {
    for (unsigned i = 0; i < axis.size(); i++) {
      // Normalize a negative axis to its non-negative equivalent before checking.
      if (axis[i] < 0)
        axis[i] += xShape.size();
      bool flag = (0 <= axis[i]) && (axis[i] < static_cast<int64_t>(xShape.size()));
      if (!flag) {
        GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis,
                              "axis value is out of range of [-rank(input), rank(input)).");
        GE_OP_LOGE(op.GetName().c_str(), "axis value is out of range of [-rank(input), rank(input)).");
        return GRAPH_FAILED;
      }
      // Squeeze may only remove dimensions whose extent is exactly 1.
      if (!(xShape[axis[i]] == 1)) {
        GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, "input shape has dim not equal to 1.");
        GE_OP_LOGE(op.GetName().c_str(), "input shape has dim not equal to 1.");
        return GRAPH_FAILED;
      }
    }
  }
  GE_OP_LOGD("SqueezeVerify Success!");
  return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(Squeeze, SqueezeVerify);
// Shape inference for Squeeze: removes size-1 dimensions. When "axis" is
// non-empty only those dimensions are removed; otherwise every dimension of
// extent 1 is removed. Propagates dtype and (where available) shape ranges.
IMPLEMT_INFERFUNC(Squeeze, SqueezeInfer) {
  GE_OP_LOGD("Enter Squeeze Infershape!");
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto axis = op.get_attr_axis();
  auto input_desc_x = op_desc->MutableInputDesc("x");
  auto output_desc_y = op_desc->MutableOutputDesc("y");
  auto input_shape = input_desc_x->MutableShape();
  int64_t dim_size = input_shape.GetDimNum();
  auto x_data_type = input_desc_x->GetDataType();
  int32_t axis_num = axis.size();
  // process -2(UnknownRank): unknown rank in, unknown rank out.
  if (input_shape.GetDims() == UNKNOWN_RANK) {
    GE_OP_LOGD("Input x shape is -2!");
    output_desc_y->SetShape(GeShape(UNKNOWN_RANK));
    output_desc_y->SetOriginShape(GeShape(UNKNOWN_RANK));
    output_desc_y->SetDataType(x_data_type);
    return GRAPH_SUCCESS;
  }
  std::vector<std::pair<int64_t, int64_t>> x_range;
  std::vector<std::pair<int64_t, int64_t>> y_range;
  input_desc_x->GetShapeRange(x_range);
  // Normalize the axis list into a set of non-negative dimension indices.
  std::unordered_set<int32_t> squeeze_dims;
  for (int32_t i = 0; i < axis_num; ++i) {
    int32_t dim = axis[i];
    if (dim < -dim_size || dim >= dim_size) {
      string reason = "Tried to squeeze dim index[" + std::to_string(dim) + "] for tensor with [" +
                      std::to_string(dim_size) + "] dimensions";
      GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason);
      GE_OP_LOGE(op.GetName().c_str(), "Tried to squeeze dim index[%d] for tensor with [%lld] dimensions", dim,
                 dim_size);
      return GRAPH_FAILED;
    }
    if (dim < 0) {
      dim = dim_size + dim;
    }
    squeeze_dims.insert(dim);
  }
  vector<int64_t> out_shape;
  for (int i = 0; i < dim_size; i++) {
    auto exist_dim = input_shape.GetDim(i);
    // If squeeze_set is non-empty, only squeeze those dimensions.
    if (!squeeze_dims.empty()) {
      if (squeeze_dims.count(i) > 0) {
        // If dim is -1 and been pointed by axis , do think -1 is 1.because no method to do verify
        if (exist_dim != 1 && exist_dim != UNKNOWN_DIM) {
          string reason = "Can not squeeze dim[" + std::to_string(i) + "], expected a dimension of 1, got [" +
                          std::to_string(exist_dim) + "]";
          GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason);
          GE_OP_LOGE(op.GetName().c_str(), "Can not squeeze dim[%d], expected a dimension of 1, got %lld", i,
                     exist_dim);
          return GRAPH_FAILED;
        }
      } else {
        // Kept dimension: copy its extent and (if present) its range.
        out_shape.emplace_back(exist_dim);
        // after verified, it has ensure x_range ele num is same with dims num
        if (!x_range.empty()) {
          y_range.emplace_back(x_range[i]);
        }
      }
    } else {
      // Copy over all non-1-length dimensions.
      // here no methed to ensure which -1 is 1, so do warning
      if (exist_dim != 1) {
        if (exist_dim == -1) {
          GE_OP_LOGW("the [%d] dim is -1, it will not execute squeeze on it! maybe influence result", exist_dim);
        }
        out_shape.emplace_back(exist_dim);
        // after verified, it has ensure x_range ele num is same with dims num
        if (!x_range.empty()) {
          y_range.emplace_back(x_range[i]);
        }
      }
    }
  }
  output_desc_y->SetShape(GeShape(out_shape));
  output_desc_y->SetOriginShape(GeShape(out_shape));
  output_desc_y->SetDataType(x_data_type);
  if (!y_range.empty()) {
    output_desc_y->SetShapeRange(y_range);
  }
  return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Squeeze, SqueezeInfer);
  1083. IMPLEMT_INFERFUNC(Unsqueeze, UnsqueezeInfer) {
  1084. auto axis_arr = op.get_attr_axes();
  1085. auto axis_nums = axis_arr.size();
  1086. if (axis_nums <= 0) {
  1087. string reason = "Axis_nums[" + std::to_string(axis_nums) + "] must be greater than 0";
  1088. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason);
  1089. GE_OP_LOGE(op.GetName().c_str(), "Axis_nums[%zu] must be greater than 0", axis_nums);
  1090. return GRAPH_PARAM_INVALID;
  1091. }
  1092. std::unordered_set<int64_t> values(axis_arr.begin(), axis_arr.end());
  1093. if (values.size() != axis_arr.size()) {
  1094. string reason = "Axis attribute must not contain any duplicates.";
  1095. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason);
  1096. GE_OP_LOGE(op.GetName().c_str(), "Axis attribute must not contain any duplicates.");
  1097. return GRAPH_PARAM_INVALID;
  1098. }
  1099. Shape input_shape = op.get_input_desc_x().GetShape();
  1100. int64_t dim_num = input_shape.GetDimNum() + axis_nums;
  1101. std::vector<int64_t> vec_dim(dim_num, 0);
  1102. for (size_t i = 0; i < axis_nums; i++) {
  1103. int64_t axis = axis_arr[i];
  1104. if ((axis < -dim_num) || (axis > (dim_num - 1))) {
  1105. string reason = "axis[" + std::to_string(axis_nums) + "]'s range is not in [" + std::to_string(-dim_num) + ", " +
  1106. std::to_string(dim_num - 1) + "]";
  1107. GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason);
  1108. GE_OP_LOGE(op.GetName().c_str(), "Axis %ld not in [%ld, %ld]", axis, -dim_num, dim_num);
  1109. return GRAPH_PARAM_INVALID;
  1110. }
  1111. if (axis < 0) {
  1112. axis += dim_num;
  1113. }
  1114. vec_dim.at(axis) = 1;
  1115. }
  1116. int64_t index = 0;
  1117. for (int64_t i = 0; i < dim_num; i++) {
  1118. if (vec_dim.at(i) != 1) {
  1119. vec_dim.at(i) = input_shape.GetDim(index);
  1120. index++;
  1121. }
  1122. }
  1123. TensorDesc td = op.get_output_desc_y();
  1124. td.SetShape(Shape(vec_dim));
  1125. td.SetDataType(op.get_input_desc_x().GetDataType());
  1126. (void)op.update_output_desc_y(td);
  1127. return GRAPH_SUCCESS;
  1128. }
  1129. INFER_FUNC_REG(Unsqueeze, UnsqueezeInfer);
  1130. IMPLEMT_INFERFUNC(Rank, RankInfer) {
  1131. OP_LOGI(op.GetName().c_str(), "Rank infershape start");
  1132. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1133. auto output_desc_y = op_desc->MutableOutputDesc("y");
  1134. std::vector<int64_t> oShapeVector;
  1135. output_desc_y->SetShape(GeShape(oShapeVector));
  1136. output_desc_y->SetOriginShape(GeShape(oShapeVector));
  1137. output_desc_y->SetDataType(DT_INT32);
  1138. OP_LOGI(op.GetName().c_str(), "Rank infershape end");
  1139. return GRAPH_SUCCESS;
  1140. }
  1141. INFER_FUNC_REG(Rank, RankInfer);
  1142. IMPLEMT_INFERFUNC(Size, SizeInfer) {
  1143. OP_LOGI(op.GetName().c_str(), "Size infershape start");
  1144. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1145. auto output_desc_y = op_desc->MutableOutputDesc("y");
  1146. std::vector<int64_t> oShapeVector;
  1147. output_desc_y->SetShape(GeShape(oShapeVector));
  1148. DataType out_type = DT_INT32;
  1149. GeAttrValue out_type_value;
  1150. op_desc->GetAttr("dtype", out_type_value);
  1151. out_type_value.GetValue<DataType>(out_type);
  1152. output_desc_y->SetDataType(out_type);
  1153. OP_LOGI(op.GetName().c_str(), "Size infershape end");
  1154. return GRAPH_SUCCESS;
  1155. }
  1156. INFER_FUNC_REG(Size, SizeInfer);
// Data, PlaceHolder and End all forward the input description unchanged,
// so they reuse the generic element-wise shape-and-type inference.
COMMON_INFER_FUNC_REG(Data, ELMTWISE_INFER_SHAPEANDTYPE("x", "y"));
COMMON_INFER_FUNC_REG(PlaceHolder, ELMTWISE_INFER_SHAPEANDTYPE("x", "y"));
COMMON_INFER_FUNC_REG(End, ELMTWISE_INFER_SHAPEANDTYPE("x", "y"));
  1160. IMPLEMT_INFERFUNC(PlaceholderWithDefault, PlaceholderWithDefaultInfer) {
  1161. TensorDesc input_desc = op.GetInputDesc("x");
  1162. auto dims = input_desc.GetShape().GetDims();
  1163. auto data_type = input_desc.GetDataType();
  1164. TensorDesc output_desc = op.GetOutputDesc("y");
  1165. output_desc.SetDataType(ge::DataType(data_type));
  1166. output_desc.SetShape(Shape(dims));
  1167. (void)op.UpdateOutputDesc("y", output_desc);
  1168. return GRAPH_SUCCESS;
  1169. }
  1170. INFER_FUNC_REG(PlaceholderWithDefault, PlaceholderWithDefaultInfer);
  1171. IMPLEMT_INFERFUNC(Shape, ShapeInfer) {
  1172. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1173. auto td = op_desc->MutableOutputDesc("y");
  1174. auto input_dims = op_desc->MutableInputDesc("x")->MutableShape().GetDims();
  1175. if (input_dims == UNKNOWN_RANK) {
  1176. td->SetShape(ge::GeShape(UNKNOWN_SHAPE));
  1177. td->SetOriginShape(ge::GeShape(UNKNOWN_SHAPE));
  1178. td->SetShapeRange(std::vector<std::pair<int64_t, int64_t>>{{1,kMaxDimNum}});
  1179. } else {
  1180. int64_t size = static_cast<int64_t>(input_dims.size());
  1181. std::vector<int64_t> size_v{size};
  1182. td->SetShape(ge::GeShape(size_v));
  1183. td->SetOriginShape(ge::GeShape(size_v));
  1184. }
  1185. uint32_t out_type = DT_INT32;
  1186. (void)op.GetAttr("dtype", out_type);
  1187. td->SetDataType((DataType)out_type);
  1188. std::vector<std::pair<int64_t, int64_t>> inRange;
  1189. op_desc->MutableInputDesc("x")->GetShapeRange(inRange);
  1190. if (!inRange.empty()) {
  1191. std::vector<int64_t> pre_op_range;
  1192. pre_op_range.resize(2*inRange.size());
  1193. for (int i = 0; i < pre_op_range.size(); i = i + 2) {
  1194. pre_op_range[i] = inRange[i/2].first;
  1195. pre_op_range[i + 1] = inRange[i/2].second;
  1196. }
  1197. ge::AttrUtils::SetListInt(*td, kPreOpInputShapeRange, pre_op_range);
  1198. OP_LOGD(op.GetName().c_str(), "Shape op set pre_op_range success");
  1199. }
  1200. return GRAPH_SUCCESS;
  1201. }
  1202. INFER_FUNC_REG(Shape, ShapeInfer);
  1203. IMPLEMT_INFERFUNC(ShapeN, ShapeNInfer) {
  1204. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1205. for (size_t i = 0; i < op.GetInputsSize(); i++) {
  1206. auto td = op_desc->MutableOutputDesc(i);
  1207. auto input_dims = op_desc->MutableInputDesc(i)->MutableShape().GetDims();
  1208. if (input_dims == UNKNOWN_RANK) {
  1209. td->SetShape(ge::GeShape(UNKNOWN_SHAPE));
  1210. td->SetOriginShape(ge::GeShape(UNKNOWN_SHAPE));
  1211. td->SetShapeRange(std::vector<std::pair<int64_t, int64_t>>{{1,kMaxDimNum}});
  1212. } else {
  1213. int64_t size = static_cast<int64_t>(input_dims.size());
  1214. GE_OP_LOGD(op.GetName().c_str(), "output value %ld", size);
  1215. std::vector<int64_t> size_v{size};
  1216. td->SetShape(ge::GeShape(size_v));
  1217. td->SetOriginShape(ge::GeShape(size_v));
  1218. }
  1219. uint32_t out_type = DT_INT32;
  1220. (void)op.GetAttr("dtype", out_type);
  1221. td->SetDataType((DataType)out_type);
  1222. std::vector<std::pair<int64_t, int64_t>> inRange;
  1223. op_desc->MutableInputDesc(i)->GetShapeRange(inRange);
  1224. if (!inRange.empty()) {
  1225. std::vector<int64_t> pre_op_range;
  1226. pre_op_range.resize(2*inRange.size());
  1227. for (int i = 0; i < pre_op_range.size(); i = i + 2) {
  1228. pre_op_range[i] = inRange[i/2].first;
  1229. pre_op_range[i + 1] = inRange[i/2].second;
  1230. }
  1231. ge::AttrUtils::SetListInt(*td, kPreOpInputShapeRange, pre_op_range);
  1232. OP_LOGD(op.GetName().c_str(), "ShapeN op set pre_op_range success");
  1233. }
  1234. }
  1235. return GRAPH_SUCCESS;
  1236. }
  1237. INFER_FUNC_REG(ShapeN, ShapeNInfer);
  1238. IMPLEMT_INFERFUNC(IdentityN, IdentityNInfer) {
  1239. OP_LOGI(op.GetName().c_str(), "IdentityN infershape start");
  1240. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1241. for (size_t i = 0; i < op.GetInputsSize(); i++) {
  1242. auto input_desc = op_desc->MutableInputDesc(i);
  1243. auto input_dims = input_desc->MutableShape().GetDims();
  1244. auto output_desc = op_desc->MutableOutputDesc(i);
  1245. auto intput_dtype = input_desc->GetDataType();
  1246. std::vector<std::pair<int64_t, int64_t>> input_range;
  1247. input_desc->GetShapeRange(input_range);
  1248. output_desc->SetShape(GeShape(input_dims));
  1249. output_desc->SetOriginShape(GeShape(input_dims));
  1250. output_desc->SetDataType(intput_dtype);
  1251. output_desc->SetShapeRange(input_range);
  1252. }
  1253. OP_LOGI(op.GetName().c_str(), "IdentityN infershape end");
  1254. return GRAPH_SUCCESS;
  1255. }
  1256. INFER_FUNC_REG(IdentityN, IdentityNInfer);
  1257. IMPLEMT_INFERFUNC(Identity, IdentityInfer) {
  1258. OP_LOGI(op.GetName().c_str(), "Identity infershape start");
  1259. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1260. auto input_desc_x = op_desc->MutableInputDesc("x");
  1261. auto output_desc_y = op_desc->MutableOutputDesc("y");
  1262. std::vector<int64_t> vec_dim;
  1263. vec_dim = input_desc_x->MutableShape().GetDims();
  1264. std::vector<std::pair<int64_t, int64_t>> x_range;
  1265. input_desc_x->GetShapeRange(x_range);
  1266. DataType data_type = input_desc_x->GetDataType();
  1267. output_desc_y->SetDataType(data_type);
  1268. output_desc_y->SetShape(GeShape(vec_dim));
  1269. output_desc_y->SetOriginShape(GeShape(vec_dim));
  1270. output_desc_y->SetShapeRange(x_range);
  1271. OP_LOGI(op.GetName().c_str(), "Identity infershape end");
  1272. return GRAPH_SUCCESS;
  1273. }
  1274. INFER_FUNC_REG(Identity, IdentityInfer);
  1275. IMPLEMT_INFERFUNC(ReadVariableOp, ReadVariableOpInfer) {
  1276. TensorDesc input_desc = op.GetInputDesc("x");
  1277. (void)op.UpdateOutputDesc("y", input_desc);
  1278. return GRAPH_SUCCESS;
  1279. }
  1280. INFER_FUNC_REG(ReadVariableOp, ReadVariableOpInfer);
  1281. template <typename T>
  1282. static void CaclDims(const Tensor& data, std::vector<int64_t>& vec_dim) {
  1283. int32_t size = data.GetSize() / sizeof(T);
  1284. for (int32_t i = 0; i < size; i++) {
  1285. T dim = *((T*)data.GetData() + i);
  1286. if (dim != 0) {
  1287. vec_dim.push_back(dim);
  1288. } else {
  1289. vec_dim.clear();
  1290. break;
  1291. }
  1292. }
  1293. }
  1294. template <typename T>
  1295. static void CaclDims(const GeTensorPtr& data, std::vector<int64_t>& vec_dim) {
  1296. int32_t size = data->GetData().GetSize() / sizeof(T);
  1297. for (int32_t i = 0; i < size; i++) {
  1298. void* data_ptr = (void*)data->GetData().GetData();
  1299. if (data_ptr == nullptr) {
  1300. return;
  1301. }
  1302. T dim = *((T*)data_ptr + i);
  1303. if (dim != 0) {
  1304. vec_dim.push_back(dim);
  1305. } else {
  1306. vec_dim.clear();
  1307. break;
  1308. }
  1309. }
  1310. }
  1311. IMPLEMT_INFERFUNC(Empty, EmptyInfer) {
  1312. OP_LOGI(op.GetName().c_str(), "Empty infershape start");
  1313. auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  1314. std::vector<string> dep_inputs = {"shape"};
  1315. op_desc->SetOpInferDepends(dep_inputs);
  1316. auto input_desc_shape = op_desc->MutableInputDesc("shape");
  1317. auto output_desc_y = op_desc->MutableOutputDesc("y");
  1318. auto dtype = op.get_attr_dtype();
  1319. std::vector<std::pair<int64_t, int64_t>> shape_range;
  1320. std::vector<std::pair<int64_t, int64_t>> y_range;
  1321. input_desc_shape->GetShapeRange(shape_range);
  1322. DataType data_type = input_desc_shape->GetDataType();
  1323. std::vector<int64_t> vec_dim;
  1324. if (data_type == DT_INT32) {
  1325. vec_dim = input_desc_shape->MutableShape().GetDims();
  1326. } else {
  1327. GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "Empty only support shape type 'DT_INT32'");
  1328. GE_OP_LOGE(op.GetName().c_str(), "Empty only support shape type 'DT_INT32'");
  1329. return GRAPH_PARAM_INVALID;
  1330. }
  1331. if (vec_dim == UNKNOWN_RANK) {
  1332. GE_OP_LOGD(op.GetName().c_str(), "all inputs are unknown rank!");
  1333. output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE));
  1334. output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  1335. output_desc_y->SetDataType((DataType)dtype);
  1336. return GRAPH_SUCCESS;
  1337. }
  1338. if (vec_dim == UNKNOWN_SHAPE) {
  1339. GE_OP_LOGD(op.GetName().c_str(), "shape is unknown shape!");
  1340. std::pair<int64_t, int64_t> pair({1, shape_range.size()});
  1341. y_range.emplace_back(pair);
  1342. output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE));
  1343. output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  1344. output_desc_y->SetDataType((DataType)dtype);
  1345. output_desc_y->SetShapeRange(y_range);
  1346. return GRAPH_SUCCESS;
  1347. }
  1348. auto node = NodeUtils::GetNodeFromOperator(op);
  1349. if (node == nullptr) {
  1350. OP_LOGE(op.GetName().c_str(), "Get null node ptr.");
  1351. return GRAPH_PARAM_INVALID;
  1352. }
  1353. GeTensorPtr shape_data;
  1354. std::vector<int64_t> shape_dims;
  1355. auto result = NodeUtils::GetInputConstData(node, "shape", shape_data);
  1356. if(result == GRAPH_SUCCESS) {
  1357. DataType data_type = shape_data->GetTensorDesc().GetDataType();
  1358. if (data_type == DT_INT32) {
  1359. CaclDims<int32_t>(shape_data,shape_dims);
  1360. } else if (data_type == DT_INT64) {
  1361. CaclDims<int64_t>(shape_data, shape_dims);
  1362. }
  1363. OP_LOGD(op.GetName().c_str(), "Get input const data success.");
  1364. std::pair<int64_t, int64_t> pair({1,shape_range.size()});
  1365. y_range.emplace_back(pair);
  1366. output_desc_y->SetShape(GeShape(shape_dims));
  1367. output_desc_y->SetOriginShape(GeShape(shape_dims));
  1368. output_desc_y->SetDataType((DataType)dtype);
  1369. output_desc_y->SetShapeRange(y_range);
  1370. return GRAPH_SUCCESS;
  1371. } else {
  1372. OP_LOGD(op.GetName().c_str(), "Get input const data failed!");
  1373. std::pair<int64_t, int64_t> pair({1,shape_range.size()});
  1374. y_range.emplace_back(pair);
  1375. output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE));
  1376. output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE));
  1377. output_desc_y->SetDataType((DataType)dtype);
  1378. output_desc_y->SetShapeRange(y_range);
  1379. return GRAPH_SUCCESS;
  1380. }
  1381. output_desc_y->SetShape(GeShape(vec_dim));
  1382. output_desc_y->SetOriginShape(GeShape(vec_dim));
  1383. output_desc_y->SetDataType((DataType)dtype);
  1384. OP_LOGD(op.GetName().c_str(), "Empty infershape end");
  1385. return GRAPH_SUCCESS;
  1386. }
  1387. INFER_FUNC_REG(Empty, EmptyInfer);
// Shape inference for Where: output y is [num_true, rank(x)] with dtype
// int64; num_true is only known at runtime, so the first dim is unknown.
// Requires 1 <= rank(x) <= 5.
IMPLEMT_INFERFUNC(Where, WhereInfer) {
  OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeTensorDescPtr x_desc = op_desc->MutableInputDesc(0);
  GeShape x_shape;
  if (WithRankAtLeast(x_desc, 1, x_shape) != GRAPH_SUCCESS) {
    OP_LOGE(op.GetName().c_str(), "input x must be at least 1D.");
    return GRAPH_FAILED;
  }
  if (WithRankAtMost(x_desc, 5, x_shape) != GRAPH_SUCCESS) {
    OP_LOGE(op.GetName().c_str(), "input x must be at most 5D.");
    return GRAPH_FAILED;
  }
  GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0);
  y_desc->SetDataType(DT_INT64);
  vector<int64_t> y_shape;
  auto input_dims = x_shape.GetDims();
  int64_t input_shape_size = x_shape.GetShapeSize();
  if (input_shape_size != UNKNOWN_DIM) {
    // input shape: known. num_true is bounded by the element count of x,
    // and the second dim is exactly rank(x).
    y_shape.push_back(UNKNOWN_DIM);
    y_shape.push_back(input_dims.size());
    std::vector<std::pair<int64_t, int64_t>> range;
    int64_t dims_num = x_shape.GetDimNum();
    range.emplace_back(std::make_pair(1, input_shape_size));
    range.emplace_back(std::make_pair(dims_num, dims_num));
    y_desc->SetShapeRange(range);
  } else {
    if (input_dims == UNKNOWN_RANK) {
      // input shape: unknown rank -> both output dims unknown.
      y_shape.push_back(UNKNOWN_DIM);
      y_shape.push_back(UNKNOWN_DIM);
    } else {
      // input shape: unknown dims, known rank.
      y_shape.push_back(UNKNOWN_DIM);
      y_shape.push_back(input_dims.size());
    }
  }
  y_desc->SetShape(GeShape(y_shape));
  y_desc->SetOriginShape(GeShape(y_shape));
  return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Where, WhereInfer);
// Shape inference for TransShape: copies the "outShape" attribute into the
// output description.
IMPLEMT_INFERFUNC(TransShape, TransShapeInfer) {
  TensorDesc y_desc = op.GetOutputDesc("y");
  vector<int64_t> output_shape;
  auto ret = op.GetAttr("outShape", output_shape);
  if (ret != GRAPH_SUCCESS) {
    // NOTE(review): this path logs an error yet returns GRAPH_SUCCESS,
    // leaving the output shape untouched — looks like deliberate best-effort
    // behavior, but confirm whether GRAPH_FAILED was intended here.
    OP_LOGE(op.GetName().c_str(), "Failed to get attribute value.");
    return GRAPH_SUCCESS;
  }
  y_desc.SetShape(Shape(output_shape));
  if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
INFER_FUNC_REG(TransShape, TransShapeInfer);
  1445. // ----------------SortV2 Begin-------------------
  1446. IMPLEMT_INFERFUNC(SortV2, SortV2InferShape) {
  1447. TensorDesc tensordesc_input = op.GetInputDesc("x");
  1448. Shape input_shape = tensordesc_input.GetShape();
  1449. DataType input_dtype = tensordesc_input.GetDataType();
  1450. std::vector<int64_t> dims_input = input_shape.GetDims();
  1451. TensorDesc tensordesc_output1 = op.GetOutputDesc("y");
  1452. tensordesc_output1.SetShape(ge::Shape(dims_input));
  1453. tensordesc_output1.SetDataType(input_dtype);
  1454. (void)op.UpdateOutputDesc("y", tensordesc_output1);
  1455. return GRAPH_SUCCESS;
  1456. }
  1457. IMPLEMT_VERIFIER(SortV2, SortV2Verify) { return GRAPH_SUCCESS; }
  1458. INFER_FUNC_REG(SortV2, SortV2InferShape);
  1459. VERIFY_FUNC_REG(SortV2, SortV2Verify);
  1460. // ----------------SortV2 END---------------------
  1461. // ----------------Expand Begin-------------------
  1462. template<typename T> static bool ExpandCalDim(const Tensor &data,
  1463. std::vector<int64_t> &vec_dim,
  1464. std::vector<int64_t> &vec_x) {
  1465. uint32_t size_shape = data.GetSize() / sizeof(T);
  1466. uint32_t size_x = vec_x.size();
  1467. if (size_shape < size_x) {
  1468. uint32_t diff = size_x - size_shape;
  1469. for (int32_t i = 0; i < size_x; i++) {
  1470. if (i < diff) {
  1471. vec_dim.push_back(vec_x[i]);
  1472. } else {
  1473. T dim = *((T *)data.GetData() + (i - diff));
  1474. if ((vec_x[i] != dim) && (vec_x[i] != 1) && (dim != 1)) {
  1475. return false;
  1476. }
  1477. if (vec_x[i] > dim) {
  1478. vec_dim.push_back(vec_x[i]);
  1479. } else {
  1480. vec_dim.push_back(dim);
  1481. }
  1482. }
  1483. }
  1484. } else {
  1485. uint32_t diff = size_shape - size_x;
  1486. for (int32_t i = 0; i < size_shape; i++) {
  1487. T dim = *((T *)data.GetData() + i);
  1488. if (i < diff) {
  1489. vec_dim.push_back(dim);
  1490. } else {
  1491. if ((vec_x[i - diff] != dim) && (vec_x[i-diff] != 1) && (dim != 1)) {
  1492. return false;
  1493. }
  1494. if (vec_x[i - diff] > dim) {
  1495. vec_dim.push_back(vec_x[i - diff]);
  1496. } else {
  1497. vec_dim.push_back(dim);
  1498. }
  1499. }
  1500. }
  1501. }
  1502. return true;
  1503. }
  1504. IMPLEMT_COMMON_INFERFUNC(ExpandInferShape) {
  1505. Shape x_shape = op.GetInputDesc("x").GetShape();
  1506. DataType x_dtype = op.GetInputDesc("x").GetDataType();
  1507. std::vector <int64_t> dims_x = x_shape.GetDims();
  1508. Tensor data;
  1509. std::vector <int64_t> vec_dim;
  1510. TensorDesc td = op.GetOutputDesc("y");
  1511. if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) {
  1512. OP_LOGE(op.GetName().c_str(), "Get constValue failed of [shape]");
  1513. return GRAPH_FAILED;
  1514. } else {
  1515. DataType data_type = data.GetTensorDesc().GetDataType();
  1516. std::vector <int64_t> vec_dim;
  1517. if (data_type == DT_INT32) {
  1518. if (!ExpandCalDim <int32_t>(data, vec_dim, dims_x)) {
  1519. OP_LOGE(op.GetName().c_str(), "Data shape are not compatible!");
  1520. return GRAPH_FAILED;
  1521. }
  1522. } else if (data_type == DT_INT64) {
  1523. if (!ExpandCalDim <int64_t>(data, vec_dim, dims_x)) {
  1524. OP_LOGE(op.GetName().c_str(), "Data shape are not compatible!");
  1525. return GRAPH_FAILED;
  1526. }
  1527. } else {
  1528. OP_LOGE(op.GetName().c_str(), "Data type not supported!");
  1529. return GRAPH_PARAM_INVALID;
  1530. }
  1531. td.SetShape(ge::Shape(vec_dim));
  1532. td.SetDataType(x_dtype);
  1533. (void)op.UpdateOutputDesc("y", td);
  1534. return GRAPH_SUCCESS;
  1535. }
  1536. }
  1537. COMMON_INFER_FUNC_REG(Expand, ExpandInferShape);
  1538. // ----------------Expand END---------------------
  1539. // ----------------ExpandD Begin-------------------
  1540. IMPLEMT_COMMON_INFERFUNC(ExpandDInferShape) {
  1541. Shape x_shape = op.GetInputDesc("x").GetShape();
  1542. DataType x_dtype = op.GetInputDesc("x").GetDataType();
  1543. std::vector<int64_t> shape;
  1544. op.GetAttr("shape", shape);
  1545. std::vector<int64_t> dims_x = x_shape.GetDims();
  1546. TensorDesc td = op.GetOutputDesc("y");
  1547. std::vector<int64_t> dim_vec;
  1548. if (shape.size() < dims_x.size()) {
  1549. std::vector<int64_t> dims_tmp = shape;
  1550. shape = dims_x;
  1551. dims_x = dims_tmp;
  1552. }
  1553. if (shape.size() != dims_x.size()) {
  1554. int dec = shape.size() - dims_x.size();
  1555. for (int i = 0; i < dec; i++) {
  1556. dims_x.insert(dims_x.begin(), (int64_t)1);
  1557. }
  1558. }
  1559. for (size_t i = 0; i < shape.size(); i++) {
  1560. if ((shape[i] != dims_x[i]) && (shape[i] != 1) && (dims_x[i] != 1)) {
  1561. OP_LOGE(op.GetName().c_str(), "The input shape and attr shape are not compatible.");
  1562. return GRAPH_FAILED;
  1563. }
  1564. if (shape[i] > dims_x[i]) {
  1565. dim_vec.push_back(shape[i]);
  1566. } else {
  1567. dim_vec.push_back(dims_x[i]);
  1568. }
  1569. }
  1570. td.SetShape(ge::Shape(dim_vec));
  1571. td.SetDataType(x_dtype);
  1572. (void)op.UpdateOutputDesc("y", td);
  1573. return GRAPH_SUCCESS;
  1574. }
  1575. COMMON_INFER_FUNC_REG(ExpandD, ExpandDInferShape);
  1576. // ----------------Expand END---------------------
  1577. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示