You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

operator.cc 96 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "external/graph/operator.h"
  17. #include "external/graph/operator_factory.h"
  18. #include <cstdint>
  19. #include <algorithm>
  20. #include <mutex>
  21. #include <queue>
  22. #include <set>
  23. #include "array_ops.h"
  24. #include "debug/ge_log.h"
  25. #include "debug/ge_op_types.h"
  26. #include "debug/ge_util.h"
  27. #include "external/graph/attr_value.h"
  28. #include "external/graph/types.h"
  29. #include "framework/common/debug/ge_log.h"
  30. #include "graph/compute_graph.h"
  31. #include "graph/ge_attr_value.h"
  32. #include "graph/ge_context.h"
  33. #include "graph/ge_tensor.h"
  34. #include "graph/node.h"
  35. #include "graph/op_desc.h"
  36. #include "graph/runtime_inference_context.h"
  37. #include "graph/usr_types.h"
  38. #include "graph/utils/node_utils.h"
  39. #include "graph/debug/ge_attr_define.h"
  40. #include "utils/graph_utils.h"
  41. #include "utils/op_desc_utils.h"
  42. #include "utils/tensor_adapter.h"
  43. #include "utils/tensor_utils.h"
  44. #include "utils/type_utils.h"
  45. #include <algorithm>
  46. #include <mutex>
  47. #include <queue>
  48. #include <set>
  49. using std::enable_shared_from_this;
  50. using std::make_pair;
  51. using std::shared_ptr;
  52. using std::string;
  53. using std::to_string;
  54. using std::vector;
  55. /*lint -save -e529 -e728*/
  56. namespace ge {
  57. /*lint -e446 -e732*/
  58. /*lint -e665*/
  59. class OpIO {
  60. public:
  61. OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {}
  62. ~OpIO() = default;
  63. string GetName() const { return name_; }
  64. int GetIndex() const { return index_; }
  65. OperatorImplPtr GetOwner() const { return owner_; }
  66. bool operator==(const OpIO &r_value) const {
  67. return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) &&
  68. (this->GetOwner() == r_value.GetOwner());
  69. }
  70. private:
  71. string name_;
  72. int index_;
  73. std::shared_ptr<OperatorImpl> owner_;
  74. };
  75. class TensorTypeImpl {
  76. public:
  77. TensorTypeImpl() = default;
  78. ~TensorTypeImpl() = default;
  79. std::vector<DataType> dt_vec_;
  80. };
  81. TensorType::TensorType(DataType dt) {
  82. tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>();
  83. if (tensor_type_impl_ != nullptr) {
  84. tensor_type_impl_->dt_vec_.push_back(dt);
  85. }
  86. }
  87. TensorType::TensorType(const std::initializer_list<DataType> &types) {
  88. tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>();
  89. if (tensor_type_impl_ != nullptr) {
  90. tensor_type_impl_->dt_vec_ = types;
  91. }
  92. }
  93. class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {
  94. friend class GraphBuilderImpl;
  95. friend class OpDescUtils;
  96. public:
  97. explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared<OpDesc>(name, type)) {
  98. if (op_desc_ == nullptr) {
  99. GELOGW("OpDesc make shared failed");
  100. }
  101. }
  102. explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {}
  103. explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) {
  104. if (node_ != nullptr && node_->GetOpDesc() != nullptr) {
  105. op_desc_ = node_->GetOpDesc();
  106. }
  107. }
  108. ~OperatorImpl() {}
  109. void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) {
  110. GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty");
  111. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  112. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr.");
  113. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  114. auto src_op_impl = src_oprt.GetOperatorImplPtr();
  115. GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null.");
  116. GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null.");
  117. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return,
  118. "The source operator[%s] must has one output",
  119. src_oprt.operator_impl_->op_desc_->GetName().c_str())
  120. uint32_t src_index = 0;
  121. string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index);
  122. GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty.");
  123. OpIO out_handler(src_name, src_index, src_op_impl);
  124. input_link_.insert(std::make_pair(dst_name, out_handler));
  125. int dst_index = op_desc_->GetInputIndexByName(dst_name);
  126. GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
  127. op_desc_->GetName().c_str());
  128. bool is_const = false;
  129. if (src_oprt.GetOpType() == CONSTANT) {
  130. is_const = true;
  131. }
  132. auto is_input_const = op_desc_->GetIsInputConst();
  133. for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
  134. is_input_const.push_back(false);
  135. }
  136. is_input_const[dst_index] = is_const;
  137. op_desc_->SetIsInputConst(is_input_const);
  138. OpIO op_dst(dst_name, dst_index, shared_from_this());
  139. src_op_impl->UpdateLinkMapImpl(src_name, op_dst);
  140. auto output_desc = src_op_impl->GetOutputDesc(src_name);
  141. auto input_desc = op_desc_->GetInputDesc(dst_name);
  142. if (input_desc.GetFormat() == FORMAT_RESERVED) {
  143. output_desc.SetFormat(FORMAT_ND);
  144. } else {
  145. output_desc.SetFormat(input_desc.GetFormat());
  146. }
  147. // Fix for linking opdesc
  148. if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) {
  149. GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(),
  150. src_name.c_str());
  151. return;
  152. }
  153. }
  154. void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) {
  155. GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty");
  156. GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr.");
  157. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  158. input_link_.insert(std::make_pair(dst_name, *out_handler));
  159. string src_name = out_handler->GetName();
  160. int dst_index = op_desc_->GetInputIndexByName(dst_name);
  161. GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
  162. op_desc_->GetName().c_str());
  163. auto out_op_impl = out_handler->GetOwner();
  164. GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return,
  165. "out_handler invalid. name[%s]", dst_name.c_str());
  166. bool is_const = false;
  167. if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) {
  168. is_const = true;
  169. }
  170. auto is_input_const = op_desc_->GetIsInputConst();
  171. for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
  172. is_input_const.push_back(false);
  173. }
  174. is_input_const[dst_index] = is_const;
  175. op_desc_->SetIsInputConst(is_input_const);
  176. OpIO in_handler(dst_name, dst_index, shared_from_this());
  177. GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed.");
  178. out_op_impl->UpdateLinkMapImpl(src_name, in_handler);
  179. auto src_output_desc = out_op_impl->GetOutputDesc(src_name);
  180. auto dst_input_desc = op_desc_->GetInputDesc(dst_name);
  181. if (dst_input_desc.GetFormat() == FORMAT_RESERVED) {
  182. src_output_desc.SetFormat(FORMAT_ND);
  183. } else {
  184. src_output_desc.SetFormat(dst_input_desc.GetFormat());
  185. }
  186. GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return,
  187. "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(),
  188. src_name.c_str()); // fix for linking opdesc
  189. }
  190. void AddControlInputImp(const ge::Operator &src_oprt) {
  191. if (src_oprt.operator_impl_ == nullptr) {
  192. GELOGE(FAILED, "Src operator impl is nullptr");
  193. return;
  194. }
  195. for (auto &input : control_input_link_) {
  196. if (input.lock() == src_oprt.operator_impl_) {
  197. return;
  198. }
  199. }
  200. control_input_link_.push_back(src_oprt.operator_impl_);
  201. src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this());
  202. }
  203. graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) {
  204. auto out = input_link_.find(dst_name);
  205. if (out == input_link_.end()) {
  206. return GRAPH_FAILED;
  207. }
  208. out_handler = out->second;
  209. return GRAPH_SUCCESS;
  210. }
  211. graphStatus GetInputConstData(const string &dst_name, Tensor &data) {
  212. auto node_ptr = GetNode();
  213. if (node_ptr != nullptr) {
  214. // For inner compute graph
  215. auto op_desc = node_ptr->GetOpDesc();
  216. GE_CHECK_NOTNULL(op_desc);
  217. auto index = op_desc->GetInputIndexByName(dst_name);
  218. auto in_data_anchor = node_ptr->GetInDataAnchor(index);
  219. GE_CHECK_NOTNULL(in_data_anchor);
  220. auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  221. GE_CHECK_NOTNULL(out_data_anchor);
  222. auto peer_node = out_data_anchor->GetOwnerNode();
  223. if (peer_node->GetType() == ENTER || peer_node->GetType() == REFENTER) {
  224. auto enter_in_data_anchor = peer_node->GetInDataAnchor(0);
  225. GE_CHECK_NOTNULL(enter_in_data_anchor);
  226. auto enter_peer_out_data_anchor = enter_in_data_anchor->GetPeerOutAnchor();
  227. GE_CHECK_NOTNULL(enter_peer_out_data_anchor);
  228. peer_node = enter_peer_out_data_anchor->GetOwnerNode();
  229. }
  230. auto peer_op_desc = peer_node->GetOpDesc();
  231. GE_CHECK_NOTNULL(peer_op_desc);
  232. auto peer_op_type = peer_op_desc->GetType();
  233. if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
  234. auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node);
  235. GE_CHECK_NOTNULL(const_op_impl);
  236. Operator const_op(std::move(const_op_impl));
  237. return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
  238. } else if (peer_op_type == DATA) {
  239. auto parent_node = NodeUtils::GetParentInput(peer_node);
  240. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  241. parent_node = NodeUtils::GetParentInput(parent_node);
  242. }
  243. if ((parent_node != nullptr)
  244. && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  245. auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node);
  246. GE_CHECK_NOTNULL(const_op_impl);
  247. Operator const_op(std::move(const_op_impl));
  248. return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
  249. }
  250. }
  251. // Try get from runtime inference context
  252. auto session_id = std::to_string(GetContext().SessionId());
  253. RuntimeInferenceContext *runtime_infer_ctx = nullptr;
  254. if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
  255. GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
  256. auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
  257. if (ret == GRAPH_SUCCESS) {
  258. return GRAPH_SUCCESS;
  259. }
  260. }
  261. } else {
  262. // For outer graph
  263. return GetInputConstDataOut(dst_name, data);
  264. }
  265. auto op_name = GetName();
  266. GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str());
  267. return GRAPH_FAILED;
  268. }
  269. graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) {
  270. ge::OpIO out_handle("", 0, nullptr);
  271. if (GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) {
  272. GELOGE(FAILED, "%s get input impl failed", dst_name.c_str());
  273. return GRAPH_FAILED;
  274. }
  275. if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) {
  276. Operator const_op(out_handle.GetOwner());
  277. const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType();
  278. if (op_desc_impl_type == CONSTANTOP) {
  279. return const_op.GetAttr(op::Constant::name_attr_value(), data);
  280. } else if (op_desc_impl_type == CONSTANT) {
  281. return const_op.GetAttr(op::Const::name_attr_value(), data);
  282. }
  283. }
  284. return GRAPH_FAILED;
  285. }
  286. bool InputIsSet(const string &name) {
  287. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr.");
  288. return op_desc_->InputIsSet(name);
  289. }
  290. string GetName() const {
  291. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr.");
  292. return op_desc_->GetName();
  293. }
  294. GeTensorDesc GetInputDesc(const string &name) const {
  295. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  296. return op_desc_->GetInputDesc(name);
  297. }
  298. GeTensorDesc GetInputDesc(uint32_t index) const {
  299. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  300. return op_desc_->GetInputDesc(index);
  301. }
  302. graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) {
  303. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr.");
  304. return op_desc_->UpdateInputDesc(name, tensor_desc);
  305. }
  306. OutHandler GetOutput(const string &name) {
  307. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr.");
  308. int src_index = op_desc_->GetOutputIndexByName(name);
  309. GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str());
  310. shared_ptr<OpIO> output_ptr = ComGraphMakeShared<OpIO>(name, src_index, shared_from_this());
  311. if (output_ptr == nullptr) {
  312. GELOGE(GRAPH_FAILED, "OpIO make shared failed");
  313. return nullptr;
  314. }
  315. return output_ptr;
  316. }
  317. OutHandler GetOutput(uint32_t index) {
  318. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr.");
  319. string name = op_desc_->GetOutputNameByIndex(index);
  320. if (name.empty()) {
  321. GELOGE(GRAPH_FAILED, "Find src name by index failed. index[%u]", index);
  322. return nullptr;
  323. }
  324. shared_ptr<OpIO> output_ptr = ComGraphMakeShared<OpIO>(name, index, shared_from_this());
  325. if (output_ptr == nullptr) {
  326. GELOGE(GRAPH_FAILED, "OpIO make shared failed");
  327. return nullptr;
  328. }
  329. return output_ptr;
  330. }
  331. GeTensorDesc GetOutputDesc(const string &name) const {
  332. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  333. return op_desc_->GetOutputDesc(name);
  334. }
  335. GeTensorDesc GetOutputDesc(uint32_t index) const {
  336. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  337. return op_desc_->GetOutputDesc(index);
  338. }
  339. graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) {
  340. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  341. auto res = op_desc_->UpdateOutputDesc(name, tensor_desc);
  342. if (res == GRAPH_SUCCESS) {
  343. for (auto ol : output_links_[name]) {
  344. if (ol.GetOwner() == nullptr) {
  345. GELOGW("%s get owner is nullptr", ol.GetName().c_str());
  346. continue;
  347. }
  348. GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED,
  349. "Could not update next operator's input %s.", ol.GetName().c_str());
  350. }
  351. }
  352. return res;
  353. }
  354. size_t GetInputsSize() const {
  355. GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0);
  356. return op_desc_->GetInputsSize();
  357. }
  358. size_t GetOutputsSize() const {
  359. GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0);
  360. return op_desc_->GetOutputsSize();
  361. }
  362. graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) {
  363. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  364. return op_desc_->SetAttr(name, std::move(attr_value));
  365. }
  366. graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const {
  367. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  368. return op_desc_->GetAttr(name, attr_value);
  369. }
  370. OpDescPtr GetOpDescImpl() const { return op_desc_; }
  371. void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) {
  372. auto it_find = output_links_.find(src_name);
  373. if (it_find == output_links_.end()) {
  374. std::vector<OpIO> dsts{op_dst};
  375. output_links_.insert(std::make_pair(src_name, dsts));
  376. } else {
  377. it_find->second.push_back(op_dst);
  378. }
  379. }
  380. Operator ToOperator() { return Operator(shared_from_this()); }
  381. static OpDescPtr GetOpDesc(const Operator &oprt) {
  382. GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr);
  383. return oprt.operator_impl_->op_desc_;
  384. }
  385. void ClearOutputLinks() noexcept { output_links_.clear(); }
  386. void ClearInputLinks() noexcept { input_link_.clear(); }
  387. ge::ConstNodePtr GetNode() { return node_; }
  388. void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; }
  389. InferenceContextPtr GetInferenceContext() const { return inference_context_; }
  390. void SubgraphRegister(const std::string &ir_name, bool dynamic) {
  391. op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic);
  392. }
  393. void SubgraphCountRegister(const std::string &ir_name, uint32_t count) {
  394. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) {
  395. op_desc_->AddSubgraphName(ir_name);
  396. subgraph_names_to_builders_[ir_name] = nullptr;
  397. } else {
  398. for (uint32_t i = 0; i < count; ++i) {
  399. string key_name = ir_name + std::to_string(i);
  400. op_desc_->AddSubgraphName(key_name);
  401. subgraph_names_to_builders_[key_name] = nullptr;
  402. }
  403. }
  404. }
  405. void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) {
  406. string key_name = ir_name;
  407. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) {
  408. key_name += std::to_string(index);
  409. }
  410. auto it = subgraph_names_to_builders_.find(key_name);
  411. if (it == subgraph_names_to_builders_.end()) {
  412. GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index);
  413. return;
  414. }
  415. it->second = builder;
  416. }
  417. SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const {
  418. string key_name = ir_name;
  419. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) {
  420. key_name += std::to_string(index);
  421. }
  422. return GetSubgraphBuilder(key_name);
  423. }
  424. SubgraphBuilder GetSubgraphBuilder(const std::string &name) const {
  425. auto iter = subgraph_names_to_builders_.find(name);
  426. if (iter == subgraph_names_to_builders_.end()) {
  427. GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str());
  428. return nullptr;
  429. }
  430. return iter->second;
  431. }
  432. std::vector<std::string> GetSubgraphNames() const {
  433. std::vector<std::string> names;
  434. for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) {
  435. names.emplace_back(subgraph_name_to_type.first);
  436. }
  437. return names;
  438. }
  439. size_t GetSubgraphNamesCount() const {
  440. return op_desc_->GetSubgraphIrNames().size();
  441. }
  442. OpDescPtr op_desc_ = nullptr;
  443. private:
  444. ge::ConstNodePtr node_{nullptr};
  445. ge::InferenceContextPtr inference_context_;
  446. std::map<string, std::vector<OpIO>> output_links_{};
  447. std::map<string, OpIO> input_link_{};
  448. std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{};
  449. std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{};
  450. std::map<std::string, SubgraphBuilder> subgraph_names_to_builders_;
  451. };
  452. // Used to manage OperatorImpl instances created by ge api.
  453. class OperatorKeeper {
  454. private:
  455. OperatorKeeper() = default;
  456. ~OperatorKeeper() {
  457. for (const auto &iter : operators_) {
  458. if (iter) {
  459. iter->ClearInputLinks();
  460. iter->ClearOutputLinks();
  461. }
  462. }
  463. }
  464. std::set<OperatorImplPtr> operators_;
  465. std::mutex mutex_;
  466. public:
  467. static OperatorKeeper &GetInstance() {
  468. static OperatorKeeper instance;
  469. return instance;
  470. }
  471. void CheckInOperator(const OperatorImplPtr &op_impl) {
  472. if (op_impl) {
  473. std::lock_guard<std::mutex> lock(mutex_);
  474. operators_.insert(op_impl);
  475. }
  476. }
  477. void CheckOutOperator(const OperatorImplPtr &op_impl) {
  478. if (op_impl) {
  479. std::lock_guard<std::mutex> lock(mutex_);
  480. operators_.erase(op_impl);
  481. }
  482. }
  483. };
  484. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) {
  485. ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(node_ptr);
  486. if (operator_impl_ptr == nullptr) {
  487. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  488. return Operator("default");
  489. }
  490. return operator_impl_ptr->ToOperator();
  491. }
  492. Operator::Operator(const std::string &type) {
  493. static uint32_t index = 0;
  494. string name = type + "_" + std::to_string(index++);
  495. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, type);
  496. if (operator_impl_ == nullptr) {
  497. GELOGW("OperatorImpl make shared failed");
  498. }
  499. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  500. }
  501. Operator::Operator(const char *type) {
  502. if (type != nullptr) {
  503. std::string op_type = type;
  504. static uint32_t index = 0;
  505. string name = op_type + "_" + std::to_string(index++);
  506. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, op_type);
  507. if (operator_impl_ == nullptr) {
  508. GELOGW("OperatorImpl make shared failed");
  509. }
  510. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  511. } else {
  512. GELOGW("Operator type is nullptr.");
  513. }
  514. }
  515. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) {
  516. shared_ptr<OperatorImpl> operator_impl_ptr;
  517. operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(op_desc);
  518. if (operator_impl_ptr == nullptr) {
  519. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  520. return Operator("default");
  521. }
  522. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr);
  523. return operator_impl_ptr->ToOperator();
  524. }
  525. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) {
  526. return OperatorImpl::GetOpDesc(oprt);
  527. }
  528. GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) {
  529. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, type);
  530. if (operator_impl_ == nullptr) {
  531. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  532. return;
  533. }
  534. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  535. }
  536. GE_FUNC_HOST_VISIBILITY Operator::Operator(const AscendString &name, const AscendString &type) {
  537. if ((name.GetString() != nullptr) && (type.GetString() != nullptr)) {
  538. string op_name = name.GetString();
  539. string op_type = type.GetString();
  540. operator_impl_ = ComGraphMakeShared<OperatorImpl>(op_name, op_type);
  541. if (operator_impl_ == nullptr) {
  542. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  543. return;
  544. }
  545. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  546. } else {
  547. GELOGW("Operator input parameter is nullptr.");
  548. }
  549. }
  550. GE_FUNC_HOST_VISIBILITY Operator::Operator(const char *name, const char *type) {
  551. if ((name != nullptr) && (type != nullptr)) {
  552. string op_name = name;
  553. string op_type = type;
  554. operator_impl_ = ComGraphMakeShared<OperatorImpl>(op_name, op_type);
  555. if (operator_impl_ == nullptr) {
  556. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  557. return;
  558. }
  559. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  560. } else {
  561. GELOGW("Operator input parameter is nullptr.");
  562. }
  563. }
  564. Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); }
  565. bool Operator::IsEmpty() const {
  566. if (operator_impl_ == nullptr) {
  567. return true;
  568. }
  569. return false;
  570. }
  571. string Operator::GetName() const {
  572. if (operator_impl_ != nullptr) {
  573. return operator_impl_->GetName();
  574. }
  575. return "";
  576. }
  577. graphStatus Operator::GetName(AscendString &name) const {
  578. if (operator_impl_ != nullptr) {
  579. string op_name = operator_impl_->GetName();
  580. name = op_name.c_str();
  581. }
  582. return GRAPH_SUCCESS;
  583. }
  584. GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) {
  585. // Describe the connection relationship between operators, no create action
  586. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  587. operator_impl_->SetInputImpl(dst_name, src_oprt);
  588. return *this;
  589. }
  590. GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const char *dst_name, const ge::Operator &src_oprt) {
  591. GE_CHK_BOOL_EXEC(dst_name != nullptr, return *this, "Operator dst name is nullptr.");
  592. // Describe the connection relationship between operators, no create action
  593. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "Operator impl is nullptr.");
  594. std::string dst_op_name = dst_name;
  595. operator_impl_->SetInputImpl(dst_op_name, src_oprt);
  596. return *this;
  597. }
  598. Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) {
  599. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  600. operator_impl_->SetInputImpl(dst_name, out_handler);
  601. return *this;
  602. }
  603. Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) {
  604. auto out_handler = src_oprt.GetOutput(name);
  605. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  606. (void)SetInput(dst_name, out_handler);
  607. return *this;
  608. }
  609. Operator &Operator::SetInput(const char *dst_name, const ge::Operator &src_oprt, const char *name) {
  610. GE_CHK_BOOL_EXEC(dst_name != nullptr, return *this, "Dst name is nullptr.");
  611. GE_CHK_BOOL_EXEC(name != nullptr, return *this, "Name is nullptr.");
  612. std::string op_name = name;
  613. std::string dst_op_name = dst_name;
  614. auto out_handler = src_oprt.GetOutput(op_name);
  615. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "Out_handler is nullptr.");
  616. (void)SetInput(dst_op_name, out_handler);
  617. return *this;
  618. }
  619. Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) {
  620. auto out_handler = src_oprt.GetOutput(index);
  621. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  622. (void)SetInput(dst_name, out_handler);
  623. return *this;
  624. }
  625. Operator &Operator::SetInput(const char *dst_name, const ge::Operator &src_oprt, uint32_t index) {
  626. GE_CHK_BOOL_EXEC(dst_name != nullptr, return *this, "Dst name is nullptr.");
  627. auto out_handler = src_oprt.GetOutput(index);
  628. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  629. std::string op_dst_name = dst_name;
  630. (void)SetInput(dst_name, out_handler);
  631. return *this;
  632. }
  633. Operator &Operator::AddControlInput(const Operator &src_oprt) {
  634. if (operator_impl_ == nullptr) {
  635. GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
  636. return *this;
  637. }
  638. operator_impl_->AddControlInputImp(src_oprt);
  639. return *this;
  640. }
  641. graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const {
  642. GE_CHECK_NOTNULL(operator_impl_);
  643. graphStatus ret = operator_impl_->GetInputConstData(dst_name, data);
  644. if (ret != GRAPH_SUCCESS) {
  645. GELOGW("%s get input const data failed", dst_name.c_str());
  646. return ret;
  647. }
  648. return GRAPH_SUCCESS;
  649. }
  650. graphStatus Operator::GetInputConstData(const char *dst_name, Tensor &data) const {
  651. GE_CHECK_NOTNULL(dst_name);
  652. GE_CHECK_NOTNULL(operator_impl_);
  653. std::string op_dst_name = dst_name;
  654. graphStatus ret = operator_impl_->GetInputConstData(op_dst_name, data);
  655. if (ret != GRAPH_SUCCESS) {
  656. GELOGW("%s get input const data failed", op_dst_name.c_str());
  657. return ret;
  658. }
  659. return GRAPH_SUCCESS;
  660. }
  661. graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const {
  662. GE_CHECK_NOTNULL(operator_impl_);
  663. if (operator_impl_->GetInputConstDataOut(dst_name, data) != GRAPH_SUCCESS) {
  664. GELOGE(GRAPH_FAILED, "%s get input const data out failed", dst_name.c_str());
  665. return GRAPH_FAILED;
  666. }
  667. return GRAPH_SUCCESS;
  668. }
  669. std::shared_ptr<const Node> Operator::GetNode() const {
  670. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  671. return operator_impl_->GetNode();
  672. }
  673. TensorDesc Operator::GetInputDesc(const std::string &name) const {
  674. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  675. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name));
  676. }
  677. TensorDesc Operator::GetInputDescByName(const char *name) const {
  678. GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr.");
  679. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr.");
  680. std::string op_name = name;
  681. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name));
  682. }
  683. void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) {
  684. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  685. operator_impl_->SetInferenceContext(inference_context);
  686. }
  687. InferenceContextPtr Operator::GetInferenceContext() const {
  688. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  689. return operator_impl_->GetInferenceContext();
  690. }
  691. TensorDesc Operator::GetInputDesc(uint32_t index) const {
  692. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  693. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index));
  694. }
  695. graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const {
  696. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  697. auto check = operator_impl_->InputIsSet(name);
  698. if (check)
  699. tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name));
  700. return check ? GRAPH_SUCCESS : GRAPH_FAILED;
  701. }
  702. graphStatus Operator::TryGetInputDesc(const char *name, TensorDesc &tensor_desc) const {
  703. GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr.");
  704. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  705. std::string op_name = name;
  706. auto check = operator_impl_->InputIsSet(op_name);
  707. if (check)
  708. tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name));
  709. return check ? GRAPH_SUCCESS : GRAPH_FAILED;
  710. }
  711. graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) {
  712. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  713. return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  714. }
  715. graphStatus Operator::UpdateInputDesc(const char *name, const ge::TensorDesc &tensor_desc) {
  716. GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr.");
  717. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  718. std::string op_name = name;
  719. return operator_impl_->UpdateInputDesc(op_name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  720. }
  721. OutHandler Operator::GetOutput(const string &name) const {
  722. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  723. return operator_impl_->GetOutput(name);
  724. }
  725. OutHandler Operator::GetOutput(uint32_t index) const {
  726. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  727. return operator_impl_->GetOutput(index);
  728. }
  729. TensorDesc Operator::GetOutputDesc(const std::string &name) const {
  730. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  731. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name));
  732. }
  733. TensorDesc Operator::GetOutputDescByName(const char *name) const {
  734. GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr.");
  735. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr.");
  736. std::string op_name = name;
  737. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(op_name));
  738. }
  739. TensorDesc Operator::GetOutputDesc(uint32_t index) const {
  740. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  741. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index));
  742. }
  743. graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) {
  744. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  745. return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  746. }
  747. graphStatus Operator::UpdateOutputDesc(const char *name, const ge::TensorDesc &tensor_desc) {
  748. GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr.");
  749. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  750. std::string op_name = name;
  751. return operator_impl_->UpdateOutputDesc(op_name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  752. }
  753. TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const {
  754. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  755. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index)));
  756. }
  757. TensorDesc Operator::GetDynamicInputDesc(const char *name, uint32_t index) const {
  758. GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr.");
  759. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr.");
  760. std::string op_name = name;
  761. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name + std::to_string(index)));
  762. }
  763. graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) {
  764. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  765. return operator_impl_->UpdateInputDesc(name + std::to_string(index),
  766. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  767. }
  768. graphStatus Operator::UpdateDynamicInputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc) {
  769. GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr.");
  770. std::string op_name = name;
  771. return operator_impl_->UpdateInputDesc(op_name + std::to_string(index),
  772. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  773. }
  774. TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const {
  775. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  776. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index)));
  777. }
  778. TensorDesc Operator::GetDynamicOutputDesc(const char *name, uint32_t index) const {
  779. GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr.");
  780. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr.");
  781. std::string op_name = name;
  782. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(op_name + std::to_string(index)));
  783. }
  784. graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) {
  785. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  786. return operator_impl_->UpdateOutputDesc(name + std::to_string(index),
  787. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  788. }
  789. graphStatus Operator::UpdateDynamicOutputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc) {
  790. GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr.");
  791. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  792. std::string op_name = name;
  793. return operator_impl_->UpdateOutputDesc(op_name + std::to_string(index),
  794. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  795. }
  796. graphStatus Operator::InferShapeAndType() {
  797. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  798. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  799. return operator_impl_->GetOpDescImpl()->CallInferFunc(*this);
  800. }
  801. graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) {
  802. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  803. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  804. if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) {
  805. return GRAPH_FAILED;
  806. } else {
  807. return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify();
  808. }
  809. }
  810. GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const {
  811. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr");
  812. return operator_impl_->GetInputsSize();
  813. }
  814. GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const {
  815. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr");
  816. return operator_impl_->GetOutputsSize();
  817. }
  818. // According to op get the attrs name and type
  819. namespace {
  820. const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = {
  821. {GeAttrValue::VT_NONE, "VT_STRING"},
  822. {GeAttrValue::VT_STRING, "VT_STRING"},
  823. {GeAttrValue::VT_FLOAT, "VT_FLOAT"},
  824. {GeAttrValue::VT_BOOL, "VT_BOOL"},
  825. {GeAttrValue::VT_INT, "VT_INT"},
  826. {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"},
  827. {GeAttrValue::VT_TENSOR, "VT_TENSOR"},
  828. {GeAttrValue::VT_BYTES, "VT_BYTES"},
  829. {GeAttrValue::VT_GRAPH, "VT_GRAPH"},
  830. {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"},
  831. {GeAttrValue::VT_LIST_LIST_INT, "VT_LIST_LIST_INT"},
  832. {GeAttrValue::VT_DATA_TYPE, "VT_DATA_TYPE"},
  833. {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"},
  834. {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"},
  835. {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"},
  836. {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"},
  837. {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"},
  838. {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"},
  839. {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"},
  840. {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"},
  841. {GeAttrValue::VT_GRAPH, "VT_GRAPH"},
  842. {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"},
  843. {GeAttrValue::VT_LIST_DATA_TYPE, "VT_LIST_DATA_TYPE"},
  844. };
  845. } // namespace
  846. const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const {
  847. std::map<std::string, std::string> attr_types;
  848. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr.");
  849. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr.");
  850. std::map<string, GeAttrValue> attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  851. map<string, GeAttrValue>::iterator iter;
  852. for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) {
  853. string name = iter->first;
  854. GeAttrValue attr_value = iter->second;
  855. GeAttrValue::ValueType type = attr_value.GetValueType();
  856. auto iter2 = kAttrTypesMap.find(type);
  857. if (iter2 != kAttrTypesMap.end()) {
  858. attr_types[name] = iter2->second;
  859. }
  860. }
  861. return attr_types;
  862. }
  863. graphStatus Operator::GetAllAttrNamesAndTypes(std::map<AscendString, AscendString> &attr_name_types) const {
  864. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  865. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  866. std::map<string, GeAttrValue> attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  867. map<string, GeAttrValue>::iterator iter;
  868. for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) {
  869. string name = iter->first;
  870. GeAttrValue attr_value = iter->second;
  871. GeAttrValue::ValueType type = attr_value.GetValueType();
  872. auto iter2 = kAttrTypesMap.find(type);
  873. if (iter2 != kAttrTypesMap.end()) {
  874. AscendString temp(name.c_str());
  875. attr_name_types[temp] = AscendString(iter2->second.c_str());
  876. }
  877. }
  878. return GRAPH_SUCCESS;
  879. }
  880. void Operator::InputRegister(const string &name) {
  881. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  882. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  883. (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc());
  884. }
  885. void Operator::OptionalInputRegister(const string &name) {
  886. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  887. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  888. // [No need to verify return value]
  889. (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name,
  890. GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED));
  891. }
  892. void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  893. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  894. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  895. // [No need to verify return value]
  896. (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func);
  897. }
  898. void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  899. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  900. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  901. // [No need to verify return value]
  902. (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func);
  903. }
  904. void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  905. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  906. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  907. // [No need to verify return value]
  908. (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func);
  909. }
  910. void Operator::OutputRegister(const string &name) {
  911. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  912. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  913. // [No need to verify return value]
  914. (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc());
  915. }
  916. void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) {
  917. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  918. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  919. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return,
  920. "set int failed");
  921. (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back);
  922. }
  923. void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) {
  924. GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr.");
  925. GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr.");
  926. operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index);
  927. }
  928. int Operator::GetDynamicInputNum(const string &name) const {
  929. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
  930. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  931. int num = 0;
  932. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num,
  933. "Get %s int failed", name.c_str());
  934. return num;
  935. }
  936. int Operator::GetDynamicInputNum(const char *name) const {
  937. GE_CHK_BOOL_EXEC(name != nullptr, return 0, "Operator name is nullptr.");
  938. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "Operator impl is nullptr.");
  939. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  940. string op_name = name;
  941. int num = 0;
  942. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(op_name), num), return num,
  943. "Get %s int failed", op_name.c_str());
  944. return num;
  945. }
  946. void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) {
  947. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  948. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  949. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return,
  950. "Set %s int failed", name.c_str());
  951. (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back);
  952. }
  953. int Operator::GetDynamicOutputNum(const string &name) const {
  954. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
  955. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  956. int num = 0;
  957. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num,
  958. "Get %s int failed", name.c_str());
  959. return num;
  960. }
  961. int Operator::GetDynamicOutputNum(const char *name) const {
  962. GE_CHK_BOOL_EXEC(name != nullptr, return 0, "Operator name is nullptr.");
  963. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "Operator impl is nullptr.");
  964. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  965. std::string op_name = name;
  966. int num = 0;
  967. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(op_name), num), return num,
  968. "Get %s int failed", op_name.c_str());
  969. return num;
  970. }
  971. void Operator::RequiredAttrRegister(const string &name) {
  972. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  973. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  974. operator_impl_->GetOpDescImpl()->AddRequiredAttr(name);
  975. }
  976. graphStatus Operator::VerifyAll() {
  977. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  978. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  979. // Check all inputs defined
  980. for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) {
  981. GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname),
  982. GRAPH_FAILED, "operator input %s is not linked.", iname.c_str());
  983. vector<int64_t> ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims();
  984. for (int64_t dim : ishape) {
  985. GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
  986. iname.c_str());
  987. }
  988. }
  989. // Check all attributes defined
  990. const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  991. for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) {
  992. GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
  993. "operator attribute %s is empty.", name.c_str());
  994. }
  995. return GRAPH_SUCCESS;
  996. }
  997. string Operator::GetOpType() const {
  998. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr.");
  999. return OperatorImpl::GetOpDesc(*this)->GetType();
  1000. }
  1001. graphStatus Operator::GetOpType(AscendString &type) const {
  1002. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  1003. std::string op_type = OperatorImpl::GetOpDesc(*this)->GetType();
  1004. type = op_type.c_str();
  1005. return GRAPH_SUCCESS;
  1006. }
  1007. Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) {
  1008. string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index);
  1009. return SetInput(dynamic_dst_name, src_oprt);
  1010. }
  1011. Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt,
  1012. const std::string &name) {
  1013. string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index);
  1014. return SetInput(dynamic_dst_name, src_oprt, name);
  1015. }
  1016. OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
  // Generates both Operator::SetAttr overloads (std::string name and C-string
  // name) for one attribute value type.
  //   ArgType      - parameter type of the attribute value.
  //   AttrUtilsFun - suffix of the AttrUtils::Set<Suffix> function to call.
  // Both overloads always return *this. A missing impl / op desc (or a null
  // C-string name) is logged as an error; a failed AttrUtils call is only
  // logged as a warning.
  #define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun)                                                  \
    Operator &Operator::SetAttr(const string &name, ArgType attr_value) {                         \
      if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {              \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());                 \
        return *this;                                                                             \
      }                                                                                           \
      const bool stored = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name,     \
                                                       attr_value);                               \
      if (!stored) {                                                                              \
        GELOGW("set attr name %s failed.", name.c_str());                                         \
      }                                                                                           \
      return *this;                                                                               \
    }                                                                                             \
    Operator &Operator::SetAttr(const char *name, ArgType attr_value) {                           \
      if (name == nullptr) {                                                                      \
        GELOGE(GRAPH_FAILED, "operator attr name is nullptr.");                                   \
        return *this;                                                                             \
      }                                                                                           \
      const std::string attr_name = name;                                                         \
      if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {              \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", attr_name.c_str());            \
        return *this;                                                                             \
      }                                                                                           \
      const bool stored = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), attr_name, \
                                                       attr_value);                               \
      if (!stored) {                                                                              \
        GELOGW("set attr name %s failed.", attr_name.c_str());                                    \
      }                                                                                           \
      return *this;                                                                               \
    }
  // Generates both Operator::GetAttr overloads (std::string name and C-string
  // name) for one attribute value type.
  //   ArgType      - parameter type of the output value (a reference type in
  //                  the instantiations in this file).
  //   AttrUtilsFun - suffix of the AttrUtils::Get<Suffix> function to call.
  // Returns GRAPH_SUCCESS when the attribute was read; GRAPH_FAILED when the
  // impl / op desc is missing, the C-string name is null, or the AttrUtils
  // call fails (the last case is logged as a warning only).
  #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun)                                                  \
    graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const {                 \
      if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {              \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());                 \
        return GRAPH_FAILED;                                                                      \
      }                                                                                           \
      const bool fetched = AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name,    \
                                                        attr_value);                              \
      if (fetched) {                                                                              \
        return GRAPH_SUCCESS;                                                                     \
      }                                                                                           \
      GELOGW("get attr name %s failed.", name.c_str());                                           \
      return GRAPH_FAILED;                                                                        \
    }                                                                                             \
    graphStatus Operator::GetAttr(const char *name, ArgType attr_value) const {                   \
      if (name == nullptr) {                                                                      \
        GELOGE(GRAPH_FAILED, "operator attr name is nullptr.");                                   \
        return GRAPH_FAILED;                                                                      \
      }                                                                                           \
      if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {              \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name);                         \
        return GRAPH_FAILED;                                                                      \
      }                                                                                           \
      const std::string attr_name = name;                                                         \
      const bool fetched = AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(),          \
                                                        attr_name, attr_value);                   \
      if (fetched) {                                                                              \
        return GRAPH_SUCCESS;                                                                     \
      }                                                                                           \
      GELOGW("get attr name %s failed.", attr_name.c_str());                                      \
      return GRAPH_FAILED;                                                                        \
    }
  1071. void Operator::BreakConnect() const {
  1072. if (operator_impl_ == nullptr) {
  1073. GELOGW("operator impl is nullptr.");
  1074. return;
  1075. }
  1076. operator_impl_->ClearInputLinks();
  1077. operator_impl_->ClearOutputLinks();
  1078. OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_);
  1079. }
  // Generates Operator::AttrRegister for one attribute value type.
  //   ArgType      - parameter type of the attribute value.
  //   AttrUtilsFun - suffix of the AttrUtils::Set<Suffix> function to call.
  // A missing impl / op desc is logged as an error; a failed AttrUtils call
  // is only logged as a warning.
  #define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun)                                                  \
    void Operator::AttrRegister(const string &name, ArgType attr_value) {                         \
      if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {              \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());                 \
        return;                                                                                   \
      }                                                                                           \
      const bool registered = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, \
                                                           attr_value);                           \
      if (!registered) {                                                                          \
        GELOGW("reg attr name %s failed.", name.c_str());                                         \
      }                                                                                           \
    }  // lint !e665
  1090. OP_ATTR_SET_IMP(int64_t, Int)
  1091. OP_ATTR_SET_IMP(int32_t, Int)
  1092. OP_ATTR_SET_IMP(uint32_t, Int)
  1093. OP_ATTR_GET_IMP(int64_t &, Int)
  1094. OP_ATTR_GET_IMP(int32_t &, Int)
  1095. OP_ATTR_GET_IMP(uint32_t &, Int)
  1096. OP_ATTR_SET_IMP(const vector<int64_t> &, ListInt)
  1097. OP_ATTR_SET_IMP(const vector<int32_t> &, ListInt)
  1098. OP_ATTR_SET_IMP(const vector<uint32_t> &, ListInt)
  1099. OP_ATTR_SET_IMP(std::initializer_list<int64_t> &&, ListInt)
  1100. OP_ATTR_GET_IMP(vector<int64_t> &, ListInt)
  1101. OP_ATTR_GET_IMP(vector<int32_t> &, ListInt)
  1102. OP_ATTR_GET_IMP(vector<uint32_t> &, ListInt)
  1103. OP_ATTR_GET_IMP(vector<vector<int64_t>> &, ListListInt)
  1104. OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt)
  1105. OP_ATTR_SET_IMP(float, Float)
  1106. OP_ATTR_GET_IMP(float &, Float)
  1107. OP_ATTR_SET_IMP(const vector<float> &, ListFloat)
  1108. OP_ATTR_GET_IMP(vector<float> &, ListFloat) // lint !e665
  1109. OP_ATTR_SET_IMP(bool, Bool)
  1110. OP_ATTR_GET_IMP(bool &, Bool)
  1111. OP_ATTR_SET_IMP(const vector<bool> &, ListBool)
  1112. OP_ATTR_GET_IMP(vector<bool> &, ListBool) // lint !e665
  1113. OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  1114. OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  1115. OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  1116. OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) // lint !e665
  1117. OP_ATTR_REG_IMP(int64_t, Int)
  1118. OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt)
  1119. OP_ATTR_REG_IMP(float, Float)
  1120. OP_ATTR_REG_IMP(const vector<float> &, ListFloat)
  1121. OP_ATTR_REG_IMP(const string &, Str)
  1122. OP_ATTR_REG_IMP(const vector<string> &, ListStr)
  1123. OP_ATTR_REG_IMP(bool, Bool)
  1124. OP_ATTR_REG_IMP(const vector<bool> &, ListBool)
  1125. OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt)
  1126. OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  1127. OP_ATTR_REG_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  1128. #undef OP_ATTR_SET_IMP
  1129. #undef OP_ATTR_GET_IMP
  1130. #undef OP_ATTR_REG_IMP
  1131. void Operator::AttrRegister(const string &name, const AscendString &attr_value) {
  1132. if (attr_value.GetString() == nullptr) {
  1133. GELOGE(GRAPH_FAILED, "Attr %s register param is invalid.", name.c_str());
  1134. return;
  1135. }
  1136. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1137. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str());
  1138. return;
  1139. }
  1140. std::string str_attr_value = attr_value.GetString();
  1141. if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), name, str_attr_value)) {
  1142. GELOGW("Reg attr name %s failed.", name.c_str());
  1143. }
  1144. }
  1145. void Operator::AttrRegister(const string &name, const std::vector<AscendString> &attr_value) {
  1146. std::vector<std::string> str_attr_values;
  1147. for (auto &val : attr_value) {
  1148. if (val.GetString() == nullptr) {
  1149. GELOGE(GRAPH_FAILED, "Attr %s register value is invalid.", name.c_str());
  1150. return;
  1151. }
  1152. str_attr_values.emplace_back(val.GetString());
  1153. }
  1154. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1155. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str());
  1156. return;
  1157. }
  1158. if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), name, str_attr_values)) {
  1159. GELOGW("Reg attr name %s failed.", name.c_str());
  1160. }
  1161. }
  1162. Operator &Operator::SetAttr(const string &name, const string &attr_value) {
  1163. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1164. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str());
  1165. return *this;
  1166. }
  1167. if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1168. GELOGW("Set attr name %s failed.", name.c_str());
  1169. }
  1170. return *this;
  1171. }
  1172. graphStatus Operator::GetAttr(const string &name, string &attr_value) const {
  1173. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1174. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str());
  1175. return GRAPH_FAILED;
  1176. }
  1177. if (!AttrUtils::GetStr(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1178. GELOGW("Get attr name %s failed.", name.c_str());
  1179. return GRAPH_FAILED;
  1180. }
  1181. return GRAPH_SUCCESS;
  1182. }
  1183. Operator &Operator::SetAttr(const string &name, const std::vector<string> &attr_value) {
  1184. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1185. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str());
  1186. return *this;
  1187. }
  1188. if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1189. GELOGW("Set attr name %s failed.", name.c_str());
  1190. }
  1191. return *this;
  1192. }
  1193. graphStatus Operator::GetAttr(const string &name, std::vector<string> &attr_value) const {
  1194. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1195. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str());
  1196. return GRAPH_FAILED;
  1197. }
  1198. if (!AttrUtils::GetListStr(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1199. GELOGW("Get attr name %s failed.", name.c_str());
  1200. return GRAPH_FAILED;
  1201. }
  1202. return GRAPH_SUCCESS;
  1203. }
  1204. Operator &Operator::SetAttr(const char *name, const char *attr_value) {
  1205. if (name == nullptr || attr_value == nullptr) {
  1206. GELOGE(GRAPH_FAILED, "Operator input parameters is nullptr.");
  1207. return *this;
  1208. }
  1209. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1210. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1211. return *this;
  1212. }
  1213. std::string op_name = name;
  1214. std::string op_attr_value = attr_value;
  1215. if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) {
  1216. GELOGW("Set attr name %s failed.", op_name.c_str());
  1217. }
  1218. return *this;
  1219. }
  1220. Operator &Operator::SetAttr(const char *name, const AscendString &attr_value) {
  1221. if (name == nullptr || attr_value.GetString() == nullptr) {
  1222. GELOGE(GRAPH_FAILED, "Operator input parameters is nullptr.");
  1223. return *this;
  1224. }
  1225. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1226. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1227. return *this;
  1228. }
  1229. std::string op_name = name;
  1230. std::string op_attr_value = attr_value.GetString();
  1231. if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) {
  1232. GELOGW("Set attr name %s failed.", op_name.c_str());
  1233. }
  1234. return *this;
  1235. }
  1236. graphStatus Operator::GetAttr(const char *name, AscendString &attr_value) const {
  1237. if (name == nullptr) {
  1238. GELOGE(GRAPH_FAILED, "Operator input parameters is nullptr.");
  1239. return GRAPH_FAILED;
  1240. }
  1241. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1242. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1243. return GRAPH_FAILED;
  1244. }
  1245. std::string op_name = name;
  1246. std::string op_attr_value;
  1247. if (!AttrUtils::GetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) {
  1248. GELOGW("Get attr name %s failed.", op_name.c_str());
  1249. return GRAPH_FAILED;
  1250. }
  1251. attr_value = AscendString(op_attr_value.c_str());
  1252. return GRAPH_SUCCESS;
  1253. }
  1254. Operator &Operator::SetAttr(const char *name, const std::vector<AscendString> &attr_values) {
  1255. if (name == nullptr) {
  1256. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1257. return *this;
  1258. }
  1259. std::vector<std::string> op_attr_values;
  1260. for (auto &attr_value : attr_values) {
  1261. if (attr_value.GetString() == nullptr) {
  1262. GELOGE(GRAPH_FAILED, "Operator ascend string name is nullptr.");
  1263. return *this;
  1264. }
  1265. op_attr_values.emplace_back(attr_value.GetString());
  1266. }
  1267. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1268. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1269. return *this;
  1270. }
  1271. std::string op_name = name;
  1272. if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_values)) {
  1273. GELOGW("Set attr name %s failed.", op_name.c_str());
  1274. }
  1275. return *this;
  1276. }
  1277. graphStatus Operator::GetAttr(const char *name, std::vector<AscendString> &attr_value) const {
  1278. if (name == nullptr) {
  1279. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1280. return GRAPH_FAILED;
  1281. }
  1282. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1283. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1284. return GRAPH_FAILED;
  1285. }
  1286. std::string op_name = name;
  1287. std::vector<std::string> op_attr_values;
  1288. if (!AttrUtils::GetListStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_values)) {
  1289. GELOGW("Get attr name %s failed.", op_name.c_str());
  1290. return GRAPH_FAILED;
  1291. }
  1292. for (auto &op_attr_value : op_attr_values) {
  1293. attr_value.emplace_back(AscendString(op_attr_value.c_str()));
  1294. }
  1295. return GRAPH_SUCCESS;
  1296. }
  1297. Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) {
  1298. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1299. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1300. return *this;
  1301. }
  1302. GeTensor tensor = TensorAdapter::AsGeTensor(attr_value);
  1303. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  1304. GELOGW("set attr name %s failed.", name.c_str());
  1305. }
  1306. return *this;
  1307. }
  1308. Operator &Operator::SetAttr(const char *name, const Tensor &attr_value) {
  1309. if (name == nullptr) {
  1310. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1311. return *this;
  1312. }
  1313. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1314. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name);
  1315. return *this;
  1316. }
  1317. std::string op_name = name;
  1318. GeTensor tensor = TensorAdapter::AsGeTensor(attr_value);
  1319. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), op_name, tensor)) {
  1320. GELOGW("set attr name %s failed.", op_name.c_str());
  1321. }
  1322. return *this;
  1323. }
  1324. Operator &Operator::SetAttr(const string &name, const vector<Tensor> &attr_value) {
  1325. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1326. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1327. return *this;
  1328. }
  1329. vector<GeTensor> val_list;
  1330. for (const auto &item : attr_value) {
  1331. auto tensor = TensorAdapter::AsGeTensor(item);
  1332. val_list.push_back(tensor);
  1333. }
  1334. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  1335. GELOGW("set attr name %s failed.", name.c_str());
  1336. }
  1337. return *this;
  1338. }
  1339. Operator &Operator::SetAttr(const char *name, const vector<Tensor> &attr_value) {
  1340. if (name == nullptr) {
  1341. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1342. return *this;
  1343. }
  1344. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1345. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1346. return *this;
  1347. }
  1348. std::string op_name = name;
  1349. vector<GeTensor> val_list;
  1350. for (const auto &item : attr_value) {
  1351. auto tensor = TensorAdapter::AsGeTensor(item);
  1352. val_list.push_back(tensor);
  1353. }
  1354. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), op_name, val_list)) {
  1355. GELOGW("Set attr name %s failed.", op_name.c_str());
  1356. }
  1357. return *this;
  1358. }
  1359. graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const {
  1360. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1361. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1362. return GRAPH_FAILED;
  1363. }
  1364. ConstGeTensorPtr tensor;
  1365. if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  1366. GELOGW("get attr name %s failed.", name.c_str());
  1367. return GRAPH_FAILED;
  1368. }
  1369. attr_value = TensorAdapter::GeTensor2Tensor(tensor);
  1370. return GRAPH_SUCCESS;
  1371. }
  1372. graphStatus Operator::GetAttr(const char *name, Tensor &attr_value) const {
  1373. if (name == nullptr) {
  1374. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1375. return GRAPH_FAILED;
  1376. }
  1377. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1378. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name);
  1379. return GRAPH_FAILED;
  1380. }
  1381. std::string op_name = name;
  1382. ConstGeTensorPtr tensor;
  1383. if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), op_name, tensor)) {
  1384. GELOGW("get attr name %s failed.", op_name.c_str());
  1385. return GRAPH_FAILED;
  1386. }
  1387. attr_value = TensorAdapter::GeTensor2Tensor(tensor);
  1388. return GRAPH_SUCCESS;
  1389. }
  1390. graphStatus Operator::GetAttr(const string &name, vector<Tensor> &attr_value) const {
  1391. attr_value.clear();
  1392. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1393. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1394. return GRAPH_FAILED;
  1395. }
  1396. vector<ConstGeTensorPtr> val_list;
  1397. if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  1398. GELOGW("get attr name %s failed.", name.c_str());
  1399. return GRAPH_FAILED;
  1400. }
  1401. for (auto &tensor : val_list) {
  1402. attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor));
  1403. }
  1404. return GRAPH_SUCCESS;
  1405. }
  1406. graphStatus Operator::GetAttr(const char *name, vector<Tensor> &attr_value) const {
  1407. if (name == nullptr) {
  1408. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1409. return GRAPH_FAILED;
  1410. }
  1411. attr_value.clear();
  1412. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1413. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1414. return GRAPH_FAILED;
  1415. }
  1416. std::string op_name = name;
  1417. vector<ConstGeTensorPtr> val_list;
  1418. if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), op_name, val_list)) {
  1419. GELOGW("get attr name %s failed.", op_name.c_str());
  1420. return GRAPH_FAILED;
  1421. }
  1422. for (auto &tensor : val_list) {
  1423. attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor));
  1424. }
  1425. return GRAPH_SUCCESS;
  1426. }
  1427. Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) {
  1428. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1429. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1430. return *this;
  1431. }
  1432. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name,
  1433. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  1434. GELOGW("set attr name %s failed.", name.c_str());
  1435. }
  1436. return *this;
  1437. }
  1438. Operator &Operator::SetAttr(const char *name, const OpBytes &attr_value) {
  1439. if (name == nullptr) {
  1440. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1441. return *this;
  1442. }
  1443. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1444. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1445. return *this;
  1446. }
  1447. std::string op_name = name;
  1448. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), op_name,
  1449. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  1450. GELOGW("Set attr name %s failed.", op_name.c_str());
  1451. }
  1452. return *this;
  1453. }
  1454. graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const {
  1455. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1456. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1457. return GRAPH_FAILED;
  1458. }
  1459. Buffer buffer;
  1460. if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) {
  1461. GELOGW("get attr name %s failed.", name.c_str());
  1462. return GRAPH_FAILED;
  1463. }
  1464. attr_value.clear();
  1465. if (buffer.data() == nullptr) {
  1466. GELOGE(GRAPH_FAILED, "buffer data is null.");
  1467. return GRAPH_FAILED;
  1468. }
  1469. attr_value.assign(buffer.data(), buffer.data() + buffer.size());
  1470. return GRAPH_SUCCESS;
  1471. }
  1472. graphStatus Operator::GetAttr(const char *name, OpBytes &attr_value) const {
  1473. if (name == nullptr) {
  1474. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1475. return GRAPH_FAILED;
  1476. }
  1477. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1478. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1479. return GRAPH_FAILED;
  1480. }
  1481. std::string op_name = name;
  1482. Buffer buffer;
  1483. if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), op_name, buffer)) {
  1484. GELOGW("Get attr name %s failed.", op_name.c_str());
  1485. return GRAPH_FAILED;
  1486. }
  1487. attr_value.clear();
  1488. if (buffer.data() == nullptr) {
  1489. GELOGE(GRAPH_FAILED, "Buffer data is null.");
  1490. return GRAPH_FAILED;
  1491. }
  1492. attr_value.assign(buffer.data(), buffer.data() + buffer.size());
  1493. return GRAPH_SUCCESS;
  1494. }
  1495. Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) {
  1496. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "Operator impl is nullptr.");
  1497. (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_));
  1498. return *this;
  1499. }
  1500. Operator &Operator::SetAttr(const char *name, ge::AttrValue &&attrValue) {
  1501. GE_CHK_BOOL_EXEC(name != nullptr, return *this, "Operator name is nullptr.");
  1502. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "Operator impl is nullptr.");
  1503. std::string op_name = name;
  1504. (void)operator_impl_->SetAttr(op_name, std::move(attrValue.impl->geAttrValue_));
  1505. return *this;
  1506. }
  1507. graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const {
  1508. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  1509. return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_);
  1510. }
  1511. graphStatus Operator::GetAttr(const char *name, ge::AttrValue &attrValue) const {
  1512. GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr.");
  1513. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr.");
  1514. std::string op_name = name;
  1515. return operator_impl_->GetAttr(op_name, attrValue.impl->geAttrValue_);
  1516. }
  1517. Operator &Operator::SetAttr(const string &name, const std::vector<ge::DataType> &attr_value) {
  1518. if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) {
  1519. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1520. return *this;
  1521. }
  1522. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1523. GELOGW("set attr name %s failed.", name.c_str());
  1524. }
  1525. return *this;
  1526. }
  1527. Operator &Operator::SetAttr(const char *name, const std::vector<ge::DataType> &attr_value) {
  1528. if (name == nullptr) {
  1529. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1530. return *this;
  1531. }
  1532. if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) {
  1533. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1534. return *this;
  1535. }
  1536. std::string op_name = name;
  1537. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) {
  1538. GELOGW("Set attr name %s failed.", op_name.c_str());
  1539. }
  1540. return *this;
  1541. }
  1542. graphStatus Operator::GetAttr(const string &name, std::vector<ge::DataType> &attr_value) const {
  1543. attr_value.clear();
  1544. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1545. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1546. return GRAPH_FAILED;
  1547. }
  1548. if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1549. GELOGW("get attr name %s failed.", name.c_str());
  1550. return GRAPH_FAILED;
  1551. }
  1552. return GRAPH_SUCCESS;
  1553. }
  1554. graphStatus Operator::GetAttr(const char *name, std::vector<ge::DataType> &attr_value) const {
  1555. if (name == nullptr) {
  1556. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1557. return GRAPH_FAILED;
  1558. }
  1559. attr_value.clear();
  1560. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1561. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1562. return GRAPH_FAILED;
  1563. }
  1564. std::string op_name = name;
  1565. if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) {
  1566. GELOGW("Get attr name %s failed.", op_name.c_str());
  1567. return GRAPH_FAILED;
  1568. }
  1569. return GRAPH_SUCCESS;
  1570. }
  1571. Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) {
  1572. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1573. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1574. return *this;
  1575. }
  1576. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1577. GELOGW("set attr name %s failed.", name.c_str());
  1578. }
  1579. return *this;
  1580. }
  1581. Operator &Operator::SetAttr(const char *name, const ge::DataType &attr_value) {
  1582. if (name == nullptr) {
  1583. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1584. return *this;
  1585. }
  1586. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1587. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1588. return *this;
  1589. }
  1590. std::string op_name = name;
  1591. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) {
  1592. GELOGW("Set attr name %s failed.", op_name.c_str());
  1593. }
  1594. return *this;
  1595. }
  1596. graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const {
  1597. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1598. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1599. return GRAPH_FAILED;
  1600. }
  1601. if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1602. GELOGW("get attr name %s failed.", name.c_str());
  1603. return GRAPH_FAILED;
  1604. }
  1605. return GRAPH_SUCCESS;
  1606. }
  1607. graphStatus Operator::GetAttr(const char *name, ge::DataType &attr_value) const {
  1608. if (name == nullptr) {
  1609. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1610. return GRAPH_FAILED;
  1611. }
  1612. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1613. GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name);
  1614. return GRAPH_FAILED;
  1615. }
  1616. std::string op_name = name;
  1617. if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) {
  1618. GELOGW("Get attr name %s failed.", op_name.c_str());
  1619. return GRAPH_FAILED;
  1620. }
  1621. return GRAPH_SUCCESS;
  1622. }
  1623. void Operator::AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value) {
  1624. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1625. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1626. return;
  1627. }
  1628. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1629. GELOGW("set attr name %s failed.", name.c_str());
  1630. }
  1631. }
  1632. void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) {
  1633. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1634. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1635. return;
  1636. }
  1637. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1638. GELOGW("set attr name %s failed.", name.c_str());
  1639. }
  1640. }
  1641. void Operator::AttrRegister(const string &name, const Tensor &attr_value) {
  1642. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1643. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1644. return;
  1645. }
  1646. auto tensor = TensorAdapter::AsGeTensor(attr_value);
  1647. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  1648. GELOGW("reg attr name %s failed.", name.c_str());
  1649. }
  1650. }
  1651. void Operator::AttrRegister(const string &name, const vector<Tensor> &attr_value) {
  1652. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1653. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1654. return;
  1655. }
  1656. vector<GeTensor> val_list;
  1657. for (const auto &item : attr_value) {
  1658. val_list.push_back(TensorAdapter::AsGeTensor(item));
  1659. }
  1660. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  1661. GELOGW("reg attr name %s failed.", name.c_str());
  1662. }
  1663. }
  1664. void Operator::AttrRegister(const string &name, const OpBytes &attr_value) {
  1665. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1666. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1667. return;
  1668. }
  1669. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name,
  1670. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  1671. GELOGW("reg attr name %s failed.", name.c_str());
  1672. }
  1673. }
  1674. void Operator::SubgraphRegister(const std::string &name, bool dynamic) {
  1675. if (operator_impl_ == nullptr) {
  1676. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1677. return;
  1678. }
  1679. operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic);
  1680. }
  1681. void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) {
  1682. if (operator_impl_ == nullptr) {
  1683. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1684. return;
  1685. }
  1686. operator_impl_->SubgraphCountRegister(name, count);
  1687. }
  1688. void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) {
  1689. if (operator_impl_ == nullptr) {
  1690. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str());
  1691. return;
  1692. }
  1693. operator_impl_->SetSubgraphBuilder(ir_name, index, builder);
  1694. }
  1695. std::vector<std::string> Operator::GetSubgraphNames() const {
  1696. return operator_impl_->GetSubgraphNames();
  1697. }
  1698. graphStatus Operator::GetSubgraphNames(std::vector<AscendString> &names) const {
  1699. std::vector<std::string> subgraph_names = operator_impl_->GetSubgraphNames();
  1700. for (auto &subgraph_name : subgraph_names) {
  1701. names.emplace_back(subgraph_name.c_str());
  1702. }
  1703. return GRAPH_SUCCESS;
  1704. }
  1705. SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const {
  1706. if (operator_impl_ == nullptr) {
  1707. GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
  1708. return nullptr;
  1709. }
  1710. return operator_impl_->GetSubgraphBuilder(ir_name, index);
  1711. }
  1712. SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const char *ir_name, uint32_t index) const {
  1713. if (operator_impl_ == nullptr) {
  1714. GELOGE(GRAPH_FAILED, "Operator impl is nullptr.");
  1715. return nullptr;
  1716. }
  1717. if (ir_name == nullptr) {
  1718. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1719. return nullptr;
  1720. }
  1721. std::string op_ir_name = ir_name;
  1722. return operator_impl_->GetSubgraphBuilder(op_ir_name, index);
  1723. }
  1724. SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const {
  1725. return GetDynamicSubgraphBuilder(ir_name, 0);
  1726. }
  1727. SubgraphBuilder Operator::GetSubgraphBuilder(const char *ir_name) const {
  1728. std::string graph_ir_name;
  1729. if (ir_name != nullptr) {
  1730. graph_ir_name = ir_name;
  1731. }
  1732. return GetDynamicSubgraphBuilder(graph_ir_name, 0);
  1733. }
  1734. Graph Operator::GetSubgraphImpl(const string &name) const {
  1735. if (operator_impl_ == nullptr) {
  1736. GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str());
  1737. return Graph("");
  1738. }
  1739. auto op_desc = OpDescUtils::GetOpDescFromOperator(*this);
  1740. if (op_desc == nullptr) {
  1741. GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str());
  1742. return Graph("");
  1743. }
  1744. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  1745. auto iter = subgraph_names_to_index.find(name);
  1746. if (iter == subgraph_names_to_index.end()) {
  1747. GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str());
  1748. return Graph("");
  1749. }
  1750. auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second);
  1751. if (subgraph_instance_name.empty()) {
  1752. GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added",
  1753. name.c_str(), iter->second);
  1754. return Graph("");
  1755. }
  1756. auto node = operator_impl_->GetNode();
  1757. if (node == nullptr) {
  1758. GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str());
  1759. return Graph("");
  1760. }
  1761. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  1762. if (root_graph == nullptr) {
  1763. GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str());
  1764. return Graph("");
  1765. }
  1766. auto subgraph = root_graph->GetSubgraph(subgraph_instance_name);
  1767. if (subgraph == nullptr) {
  1768. GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph",
  1769. name.c_str(), iter->second, subgraph_instance_name.c_str());
  1770. return Graph("");
  1771. }
  1772. return GraphUtils::CreateGraphFromComputeGraph(subgraph);
  1773. }
  1774. Graph Operator::GetSubgraph(const string &name) const {
  1775. return GetSubgraphImpl(name);
  1776. }
  1777. Graph Operator::GetSubgraph(const char *name) const {
  1778. if (name == nullptr) {
  1779. GELOGE(GRAPH_FAILED, "Get subgraph failed, name is nullptr.");
  1780. return Graph("");
  1781. }
  1782. std::string op_name = name;
  1783. return GetSubgraphImpl(op_name);
  1784. }
  1785. Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const {
  1786. return GetSubgraph(name + std::to_string(index));
  1787. }
  1788. Graph Operator::GetDynamicSubgraph(const char *name, uint32_t index) const {
  1789. if (name == nullptr) {
  1790. GELOGE(GRAPH_FAILED, "Operator name is nullptr.");
  1791. return Graph("");
  1792. }
  1793. std::string op_name = name;
  1794. return GetSubgraph(op_name + std::to_string(index));
  1795. }
  1796. size_t Operator::GetSubgraphNamesCount() const {
  1797. if (operator_impl_ == nullptr) {
  1798. GE_LOGE("Failed to get subgraph names count, the operator impl is null");
  1799. return 0;
  1800. }
  1801. return operator_impl_->GetSubgraphNamesCount();
  1802. }
  1803. class GraphBuilderImpl {
  1804. public:
  1805. explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) {
  1806. if (graph_ == nullptr) {
  1807. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  1808. return;
  1809. }
  1810. }
  1811. ~GraphBuilderImpl() {}
  1812. ComputeGraphPtr BuildGraph(const std::vector<Operator> &inputs) {
  1813. std::vector<OperatorImplPtr> vec_inputs;
  1814. for (auto &it : inputs) {
  1815. auto src_op_impl = it.operator_impl_;
  1816. GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null.");
  1817. GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null.");
  1818. string type = src_op_impl->op_desc_->GetType();
  1819. auto node_op = ge::OperatorFactory::CreateOperator("node_op", type);
  1820. auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  1821. node_op.BreakConnect();
  1822. GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null.");
  1823. if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA ||
  1824. type == VARIABLE || type == INITDATA || type == GETNEXT) {
  1825. vec_inputs.push_back(it.operator_impl_);
  1826. } else {
  1827. GELOGW("Input operator should be Data, Variable operator or operator that has output but no input.");
  1828. }
  1829. }
  1830. GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, "User Input do not include operator such as "
  1831. "Data, Variable operator or operator that has output but no input.");
  1832. auto ret = WalkAllOperators(vec_inputs);
  1833. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed.");
  1834. ret = AddEdge();
  1835. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed.");
  1836. return graph_;
  1837. }
  1838. const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const { return all_nodes_info_; }
  1839. private:
  1840. graphStatus WalkAllOperators(const std::vector<OperatorImplPtr> &vec_ops) {
  1841. GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.")
  1842. std::queue<std::vector<OperatorImplPtr>> que;
  1843. que.push(vec_ops);
  1844. while (!que.empty()) {
  1845. auto vec_tem = que.front();
  1846. que.pop();
  1847. for (const auto &op_impl : vec_tem) {
  1848. GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.")
  1849. GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue,
  1850. "This node %s has created.", op_impl->GetName().c_str())
  1851. auto node_ptr = graph_->AddNode(op_impl->op_desc_);
  1852. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed.");
  1853. all_nodes_info_.insert(std::make_pair(op_impl, node_ptr));
  1854. auto &out_links = op_impl->output_links_;
  1855. std::vector<OperatorImplPtr> vec_op_forward{};
  1856. for (const auto &out_link : out_links) {
  1857. for (const auto &op_forward : out_link.second) {
  1858. vec_op_forward.push_back(op_forward.GetOwner());
  1859. }
  1860. }
  1861. auto &out_control_links = op_impl->control_output_link_;
  1862. for (const auto &out_link : out_control_links) {
  1863. vec_op_forward.push_back(out_link.lock());
  1864. }
  1865. que.push(vec_op_forward);
  1866. auto &in_links = op_impl->input_link_;
  1867. std::vector<OperatorImplPtr> vec_op_back_forward{};
  1868. for (const auto &in_link : in_links) {
  1869. vec_op_back_forward.push_back(in_link.second.GetOwner());
  1870. }
  1871. auto &in_control_links = op_impl->control_input_link_;
  1872. for (const auto &in_link : in_control_links) {
  1873. vec_op_back_forward.push_back(in_link.lock());
  1874. }
  1875. que.push(vec_op_back_forward);
  1876. if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) {
  1877. return GRAPH_FAILED;
  1878. }
  1879. }
  1880. }
  1881. return MoveSubgraphToRoot(graph_);
  1882. }
  1883. graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) {
  1884. const string name = node->GetName();
  1885. for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) {
  1886. const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first);
  1887. if (builder == nullptr) {
  1888. GELOGW("Node: %s, Has no builder.", name.c_str());
  1889. continue;
  1890. }
  1891. Graph graph = builder(); // Build subgraph from user define builder.
  1892. const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph);
  1893. GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str());
  1894. subgraph->SetParentNode(node);
  1895. subgraph->SetParentGraph(graph_);
  1896. if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
  1897. return GRAPH_FAILED;
  1898. }
  1899. if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) {
  1900. GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second);
  1901. return GRAPH_FAILED;
  1902. }
  1903. }
  1904. return GRAPH_SUCCESS;
  1905. }
  1906. graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) {
  1907. const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph);
  1908. if (root_graph == nullptr) {
  1909. GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str());
  1910. return GRAPH_FAILED;
  1911. }
  1912. if (root_graph == graph) {
  1913. auto subgraphs = graph->GetAllSubgraphs();
  1914. for (auto &subgraph : subgraphs) {
  1915. if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) {
  1916. return GRAPH_FAILED;
  1917. }
  1918. }
  1919. } else {
  1920. auto subgraphs = graph->GetAllSubgraphs();
  1921. for (auto &subgraph : subgraphs) {
  1922. if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
  1923. return GRAPH_FAILED;
  1924. }
  1925. graph->RemoveSubgraph(subgraph->GetName());
  1926. if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) {
  1927. return GRAPH_FAILED;
  1928. }
  1929. }
  1930. }
  1931. return GRAPH_SUCCESS;
  1932. }
  1933. graphStatus AddEdge() {
  1934. for (const auto &node_info : all_nodes_info_) {
  1935. auto src_op_impl_ptr = node_info.first;
  1936. auto src_node_ptr = node_info.second;
  1937. GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue);
  1938. auto out_links = src_op_impl_ptr->output_links_;
  1939. GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED,
  1940. "Src operator impl's op_desc is null.");
  1941. auto &op_desc = src_op_impl_ptr->op_desc_;
  1942. GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
  1943. for (const auto &out : out_links) {
  1944. auto src_idx = op_desc->GetOutputIndexByName(out.first);
  1945. GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed");
  1946. auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx);
  1947. GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed.");
  1948. for (const auto &dst_opio : out.second) {
  1949. auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner());
  1950. GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed.");
  1951. GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);
  1952. auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex());
  1953. GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed.");
  1954. auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor);
  1955. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED,
  1956. "from node[%s][%d] to node[%s][%d]AddEdge failed.",
  1957. src_node_ptr->GetName().c_str(), src_anchor->GetIdx(),
  1958. dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx());
  1959. }
  1960. }
  1961. auto out_control_anchor = src_node_ptr->GetOutControlAnchor();
  1962. for (const auto &control_out : src_op_impl_ptr->control_output_link_) {
  1963. auto dst_node_info = all_nodes_info_.find(control_out.lock());
  1964. if (dst_node_info == all_nodes_info_.end()) {
  1965. GELOGE(GRAPH_FAILED, "Find Dst node failed.");
  1966. return GRAPH_FAILED;
  1967. }
  1968. GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);
  1969. auto in_control_anchor = dst_node_info->second->GetInControlAnchor();
  1970. auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor);
  1971. if (ret != GRAPH_SUCCESS) {
  1972. GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(),
  1973. op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(),
  1974. dst_node_info->second->GetType().c_str());
  1975. return ret;
  1976. }
  1977. }
  1978. }
  1979. return GRAPH_SUCCESS;
  1980. }
  1981. ComputeGraphPtr graph_ = nullptr;
  1982. std::map<OperatorImplPtr, NodePtr> all_nodes_info_{};
  1983. };
  1984. inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) {
  1985. for (const auto &graph : compute_graph->GetAllSubgraphs()) {
  1986. std::set<string> node_names;
  1987. for (auto const &node : graph->GetDirectNode()) {
  1988. auto result = node_names.insert(node->GetName());
  1989. if (!result.second) {
  1990. GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str());
  1991. return true;
  1992. }
  1993. }
  1994. }
  1995. std::set<string> node_names;
  1996. for (auto const &node : compute_graph->GetDirectNode()) {
  1997. auto result = node_names.insert(node->GetName());
  1998. if (!result.second) {
  1999. GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str());
  2000. return true;
  2001. }
  2002. }
  2003. return false;
  2004. }
  2005. ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) {
  2006. auto graph_builder_impl = GraphBuilderImpl(name);
  2007. ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs);
  2008. GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr");
  2009. compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo());
  2010. if (HasSameNameNode(compute_graph)) {
  2011. GELOGW("Compute do not allow has same name nodes.");
  2012. compute_graph = nullptr;
  2013. }
  2014. return compute_graph;
  2015. }
  2016. void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos) {
  2017. for (const auto &it : all_nodes_infos) {
  2018. OperatorImplPtr op_impl = it.first;
  2019. if (op_impl == nullptr) {
  2020. GELOGW("operator impl is nullptr.");
  2021. continue;
  2022. }
  2023. op_impl->ClearOutputLinks();
  2024. op_impl->ClearInputLinks();
  2025. OperatorKeeper::GetInstance().CheckOutOperator(op_impl);
  2026. }
  2027. }
  2028. /*lint +e446 +e732*/
  2029. /*lint +e665*/
  2030. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。