You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

operator.cc 65 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "external/graph/operator.h"
  17. #include "external/graph/operator_factory.h"
  18. #include <stdint.h>
  19. #include <algorithm>
  20. #include <mutex>
  21. #include <queue>
  22. #include <set>
  23. #include "./array_ops.h"
  24. #include "debug/ge_log.h"
  25. #include "debug/ge_op_types.h"
  26. #include "debug/ge_util.h"
  27. #include "external/graph/attr_value.h"
  28. #include "external/graph/types.h"
  29. #include "framework/common/debug/ge_log.h"
  30. #include "graph/compute_graph.h"
  31. #include "graph/ge_attr_value.h"
  32. #include "graph/ge_context.h"
  33. #include "graph/ge_tensor.h"
  34. #include "graph/node.h"
  35. #include "graph/op_desc.h"
  36. #include "graph/runtime_inference_context.h"
  37. #include "graph/usr_types.h"
  38. #include "graph/utils/node_utils.h"
  39. #include "graph/debug/ge_attr_define.h"
  40. #include "utils/graph_utils.h"
  41. #include "utils/op_desc_utils.h"
  42. #include "utils/tensor_adapter.h"
  43. #include "utils/tensor_utils.h"
  44. #include "utils/type_utils.h"
  45. #include <algorithm>
  46. #include <mutex>
  47. #include <queue>
  48. #include <set>
  49. #include <stdint.h>
  50. using std::enable_shared_from_this;
  51. using std::make_pair;
  52. using std::shared_ptr;
  53. using std::string;
  54. using std::to_string;
  55. using std::vector;
  56. /*lint -save -e529 -e728*/
  57. /*lint -e446 -e732*/
  58. /*lint -e665*/
  59. namespace ge {
  60. class OpIO {
  61. public:
  62. OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {}
  63. ~OpIO() = default;
  64. string GetName() const { return name_; }
  65. int GetIndex() const { return index_; }
  66. OperatorImplPtr GetOwner() const { return owner_; }
  67. bool operator==(const OpIO &r_value) const {
  68. return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) &&
  69. (this->GetOwner() == r_value.GetOwner());
  70. }
  71. private:
  72. string name_;
  73. int index_;
  74. std::shared_ptr<OperatorImpl> owner_;
  75. };
  76. class TensorTypeImpl {
  77. public:
  78. TensorTypeImpl() = default;
  79. ~TensorTypeImpl() = default;
  80. std::vector<DataType> dt_vec_;
  81. };
  82. TensorType::TensorType(DataType dt) {
  83. tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>();
  84. if (tensor_type_impl_ != nullptr) {
  85. tensor_type_impl_->dt_vec_.push_back(dt);
  86. }
  87. }
  88. TensorType::TensorType(const std::initializer_list<DataType> &types) {
  89. tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>();
  90. if (tensor_type_impl_ != nullptr) {
  91. tensor_type_impl_->dt_vec_ = types;
  92. }
  93. }
  94. class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {
  95. friend class GraphBuilderImpl;
  96. friend class OpDescUtils;
  97. public:
  98. explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared<OpDesc>(name, type)) {
  99. if (op_desc_ == nullptr) {
  100. GELOGW("OpDesc make shared failed");
  101. }
  102. }
  103. explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {}
  104. explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) {
  105. if (node_ != nullptr && node_->GetOpDesc() != nullptr) {
  106. op_desc_ = node_->GetOpDesc();
  107. }
  108. }
  109. ~OperatorImpl() {}
  110. void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) {
  111. GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty");
  112. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  113. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr.");
  114. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  115. auto src_op_impl = src_oprt.GetOperatorImplPtr();
  116. GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null.");
  117. GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null.");
  118. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return,
  119. "The source operator[%s] must has one output",
  120. src_oprt.operator_impl_->op_desc_->GetName().c_str())
  121. uint32_t src_index = 0;
  122. string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index);
  123. GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty.");
  124. OpIO out_handler(src_name, src_index, src_op_impl);
  125. input_link_.insert(std::make_pair(dst_name, out_handler));
  126. int dst_index = op_desc_->GetInputIndexByName(dst_name);
  127. GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
  128. op_desc_->GetName().c_str());
  129. bool is_const = false;
  130. if (src_oprt.GetOpType() == CONSTANT) {
  131. is_const = true;
  132. }
  133. auto is_input_const = op_desc_->GetIsInputConst();
  134. for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
  135. is_input_const.push_back(false);
  136. }
  137. is_input_const[dst_index] = is_const;
  138. op_desc_->SetIsInputConst(is_input_const);
  139. OpIO op_dst(dst_name, dst_index, shared_from_this());
  140. src_op_impl->UpdateLinkMapImpl(src_name, op_dst);
  141. auto output_desc = src_op_impl->GetOutputDesc(src_name);
  142. auto input_desc = op_desc_->GetInputDesc(dst_name);
  143. if (input_desc.GetFormat() == FORMAT_RESERVED) {
  144. output_desc.SetFormat(FORMAT_ND);
  145. } else {
  146. output_desc.SetFormat(input_desc.GetFormat());
  147. }
  148. // Fix for linking opdesc
  149. if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) {
  150. GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(),
  151. src_name.c_str());
  152. return;
  153. }
  154. }
  155. void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) {
  156. GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty");
  157. GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr.");
  158. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  159. input_link_.insert(std::make_pair(dst_name, *out_handler));
  160. string src_name = out_handler->GetName();
  161. int dst_index = op_desc_->GetInputIndexByName(dst_name);
  162. GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
  163. op_desc_->GetName().c_str());
  164. auto out_op_impl = out_handler->GetOwner();
  165. GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return,
  166. "out_handler invalid. name[%s]", dst_name.c_str());
  167. bool is_const = false;
  168. if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) {
  169. is_const = true;
  170. }
  171. auto is_input_const = op_desc_->GetIsInputConst();
  172. for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
  173. is_input_const.push_back(false);
  174. }
  175. is_input_const[dst_index] = is_const;
  176. op_desc_->SetIsInputConst(is_input_const);
  177. OpIO in_handler(dst_name, dst_index, shared_from_this());
  178. GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed.");
  179. out_op_impl->UpdateLinkMapImpl(src_name, in_handler);
  180. auto src_output_desc = out_op_impl->GetOutputDesc(src_name);
  181. auto dst_input_desc = op_desc_->GetInputDesc(dst_name);
  182. if (dst_input_desc.GetFormat() == FORMAT_RESERVED) {
  183. src_output_desc.SetFormat(FORMAT_ND);
  184. } else {
  185. src_output_desc.SetFormat(dst_input_desc.GetFormat());
  186. }
  187. GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return,
  188. "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(),
  189. src_name.c_str()); // fix for linking opdesc
  190. }
  191. void AddControlInputImp(const ge::Operator &src_oprt) {
  192. if (src_oprt.operator_impl_ == nullptr) {
  193. GELOGE(FAILED, "Src operator impl is nullptr");
  194. return;
  195. }
  196. for (auto &input : control_input_link_) {
  197. if (input.lock() == src_oprt.operator_impl_) {
  198. return;
  199. }
  200. }
  201. control_input_link_.push_back(src_oprt.operator_impl_);
  202. src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this());
  203. }
  204. graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) {
  205. auto out = input_link_.find(dst_name);
  206. if (out == input_link_.end()) {
  207. return GRAPH_FAILED;
  208. }
  209. out_handler = out->second;
  210. return GRAPH_SUCCESS;
  211. }
  212. bool InputIsSet(const string &name) {
  213. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr.");
  214. return op_desc_->InputIsSet(name);
  215. }
  216. string GetName() const {
  217. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr.");
  218. return op_desc_->GetName();
  219. }
  220. GeTensorDesc GetInputDesc(const string &name) const {
  221. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  222. return op_desc_->GetInputDesc(name);
  223. }
  224. GeTensorDesc GetInputDesc(uint32_t index) const {
  225. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  226. return op_desc_->GetInputDesc(index);
  227. }
  228. graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) {
  229. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr.");
  230. return op_desc_->UpdateInputDesc(name, tensor_desc);
  231. }
  232. OutHandler GetOutput(const string &name) {
  233. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr.");
  234. int src_index = op_desc_->GetOutputIndexByName(name);
  235. GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str());
  236. shared_ptr<OpIO> output_ptr = ComGraphMakeShared<OpIO>(name, src_index, shared_from_this());
  237. if (output_ptr == nullptr) {
  238. GELOGE(GRAPH_FAILED, "OpIO make shared failed");
  239. return nullptr;
  240. }
  241. return output_ptr;
  242. }
  243. OutHandler GetOutput(uint32_t index) {
  244. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr.");
  245. string name = op_desc_->GetOutputNameByIndex(index);
  246. if (name.empty()) {
  247. GELOGE(GRAPH_FAILED, "Find src name by index failed. index[%u]", index);
  248. return nullptr;
  249. }
  250. shared_ptr<OpIO> output_ptr = ComGraphMakeShared<OpIO>(name, index, shared_from_this());
  251. if (output_ptr == nullptr) {
  252. GELOGE(GRAPH_FAILED, "OpIO make shared failed");
  253. return nullptr;
  254. }
  255. return output_ptr;
  256. }
  257. GeTensorDesc GetOutputDesc(const string &name) const {
  258. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  259. return op_desc_->GetOutputDesc(name);
  260. }
  261. GeTensorDesc GetOutputDesc(uint32_t index) const {
  262. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  263. return op_desc_->GetOutputDesc(index);
  264. }
  265. graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) {
  266. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  267. auto res = op_desc_->UpdateOutputDesc(name, tensor_desc);
  268. if (res == GRAPH_SUCCESS) {
  269. for (auto ol : output_links_[name]) {
  270. if (ol.GetOwner() == nullptr) {
  271. GELOGW("%s get owner is nullptr", ol.GetName().c_str());
  272. continue;
  273. }
  274. GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED,
  275. "Could not update next operator's input %s.", ol.GetName().c_str());
  276. }
  277. }
  278. return res;
  279. }
  280. size_t GetInputsSize() const {
  281. GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0);
  282. return op_desc_->GetInputsSize();
  283. }
  284. size_t GetOutputsSize() const {
  285. GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0);
  286. return op_desc_->GetOutputsSize();
  287. }
  288. graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) {
  289. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  290. return op_desc_->SetAttr(name, std::move(attr_value));
  291. }
  292. graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const {
  293. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  294. return op_desc_->GetAttr(name, attr_value);
  295. }
  296. OpDescPtr GetOpDescImpl() const { return op_desc_; }
  297. void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) {
  298. auto it_find = output_links_.find(src_name);
  299. if (it_find == output_links_.end()) {
  300. std::vector<OpIO> dsts{op_dst};
  301. output_links_.insert(std::make_pair(src_name, dsts));
  302. } else {
  303. it_find->second.push_back(op_dst);
  304. }
  305. }
  306. Operator ToOperator() { return Operator(shared_from_this()); }
  307. static OpDescPtr GetOpDesc(const Operator &oprt) {
  308. GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr);
  309. return oprt.operator_impl_->op_desc_;
  310. }
  311. void ClearOutputLinks() noexcept { output_links_.clear(); }
  312. void ClearInputLinks() noexcept { input_link_.clear(); }
  313. ge::ConstNodePtr GetNode() { return node_; }
  314. void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; }
  315. InferenceContextPtr GetInferenceContext() const { return inference_context_; }
  316. void SubgraphRegister(const std::string &ir_name, bool dynamic) {
  317. op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic);
  318. }
  319. void SubgraphCountRegister(const std::string &ir_name, uint32_t count) {
  320. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) {
  321. op_desc_->AddSubgraphName(ir_name);
  322. subgraph_names_to_builders_[ir_name] = nullptr;
  323. } else {
  324. for (uint32_t i = 0; i < count; ++i) {
  325. string key_name = ir_name + std::to_string(i);
  326. op_desc_->AddSubgraphName(key_name);
  327. subgraph_names_to_builders_[key_name] = nullptr;
  328. }
  329. }
  330. }
  331. void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) {
  332. string key_name = ir_name;
  333. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) {
  334. key_name += std::to_string(index);
  335. }
  336. auto it = subgraph_names_to_builders_.find(key_name);
  337. if (it == subgraph_names_to_builders_.end()) {
  338. GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index);
  339. return;
  340. }
  341. it->second = builder;
  342. }
  343. SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const {
  344. string key_name = ir_name;
  345. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) {
  346. key_name += std::to_string(index);
  347. }
  348. return GetSubgraphBuilder(key_name);
  349. }
  350. SubgraphBuilder GetSubgraphBuilder(const std::string &name) const {
  351. auto iter = subgraph_names_to_builders_.find(name);
  352. if (iter == subgraph_names_to_builders_.end()) {
  353. GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str());
  354. return nullptr;
  355. }
  356. return iter->second;
  357. }
  358. std::vector<std::string> GetSubgraphNames() const {
  359. std::vector<std::string> names;
  360. for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) {
  361. names.emplace_back(subgraph_name_to_type.first);
  362. }
  363. return names;
  364. }
  365. size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); }
  366. OpDescPtr op_desc_ = nullptr;
  367. private:
  368. ge::ConstNodePtr node_{nullptr};
  369. ge::InferenceContextPtr inference_context_;
  370. std::map<string, std::vector<OpIO>> output_links_{};
  371. std::map<string, OpIO> input_link_{};
  372. std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{};
  373. std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{};
  374. std::map<std::string, SubgraphBuilder> subgraph_names_to_builders_;
  375. };
  376. // Used to manage OperatorImpl instances created by ge api.
  377. class OperatorKeeper {
  378. private:
  379. OperatorKeeper() = default;
  380. ~OperatorKeeper() {
  381. for (const auto &iter : operators_) {
  382. if (iter) {
  383. iter->ClearInputLinks();
  384. iter->ClearOutputLinks();
  385. }
  386. }
  387. }
  388. std::set<OperatorImplPtr> operators_;
  389. std::mutex mutex_;
  390. public:
  391. static OperatorKeeper &GetInstance() {
  392. static OperatorKeeper instance;
  393. return instance;
  394. }
  395. void CheckInOperator(const OperatorImplPtr &op_impl) {
  396. if (op_impl) {
  397. std::lock_guard<std::mutex> lock(mutex_);
  398. operators_.insert(op_impl);
  399. }
  400. }
  401. void CheckOutOperator(const OperatorImplPtr &op_impl) {
  402. if (op_impl) {
  403. std::lock_guard<std::mutex> lock(mutex_);
  404. operators_.erase(op_impl);
  405. }
  406. }
  407. };
  408. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) {
  409. ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(node_ptr);
  410. if (operator_impl_ptr == nullptr) {
  411. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  412. return Operator("default");
  413. }
  414. return operator_impl_ptr->ToOperator();
  415. }
  416. Operator::Operator(const std::string &type) {
  417. static uint32_t index = 0;
  418. string name = type + "_" + std::to_string(index++);
  419. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, type);
  420. if (operator_impl_ == nullptr) {
  421. GELOGW("OperatorImpl make shared failed");
  422. }
  423. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  424. }
  425. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) {
  426. shared_ptr<OperatorImpl> operator_impl_ptr;
  427. operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(op_desc);
  428. if (operator_impl_ptr == nullptr) {
  429. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  430. return Operator("default");
  431. }
  432. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr);
  433. return operator_impl_ptr->ToOperator();
  434. }
  435. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) {
  436. return OperatorImpl::GetOpDesc(oprt);
  437. }
  438. GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) {
  439. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, type);
  440. if (operator_impl_ == nullptr) {
  441. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  442. return;
  443. }
  444. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  445. }
  446. Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); }
  447. bool Operator::IsEmpty() const {
  448. if (operator_impl_ == nullptr) {
  449. return true;
  450. }
  451. return false;
  452. }
  453. string Operator::GetName() const {
  454. if (operator_impl_ != nullptr) {
  455. return operator_impl_->GetName();
  456. }
  457. return "";
  458. }
  459. GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) {
  460. // Describe the connection relationship between operators, no create action
  461. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  462. operator_impl_->SetInputImpl(dst_name, src_oprt);
  463. return *this;
  464. }
  465. Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) {
  466. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  467. operator_impl_->SetInputImpl(dst_name, out_handler);
  468. return *this;
  469. }
  470. Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) {
  471. auto out_handler = src_oprt.GetOutput(name);
  472. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  473. (void)SetInput(dst_name, out_handler);
  474. return *this;
  475. }
  476. Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) {
  477. auto out_handler = src_oprt.GetOutput(index);
  478. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  479. (void)SetInput(dst_name, out_handler);
  480. return *this;
  481. }
  482. Operator &Operator::AddControlInput(const Operator &src_oprt) {
  483. if (operator_impl_ == nullptr) {
  484. GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
  485. return *this;
  486. }
  487. operator_impl_->AddControlInputImp(src_oprt);
  488. return *this;
  489. }
  490. graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const {
  491. GE_CHECK_NOTNULL(operator_impl_);
  492. auto node_ptr = operator_impl_->GetNode();
  493. if (node_ptr != nullptr) {
  494. // For inner compute graph
  495. auto op_desc = node_ptr->GetOpDesc();
  496. GE_CHECK_NOTNULL(op_desc);
  497. auto index = op_desc->GetInputIndexByName(dst_name);
  498. auto in_data_anchor = node_ptr->GetInDataAnchor(index);
  499. GE_CHECK_NOTNULL(in_data_anchor);
  500. auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  501. GE_CHECK_NOTNULL(out_data_anchor);
  502. auto peer_node = out_data_anchor->GetOwnerNode();
  503. GE_CHECK_NOTNULL(peer_node);
  504. auto peer_op_desc = peer_node->GetOpDesc();
  505. GE_CHECK_NOTNULL(peer_op_desc);
  506. auto peer_op_type = peer_op_desc->GetType();
  507. if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
  508. auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node);
  509. GE_CHECK_NOTNULL(const_op_impl);
  510. Operator const_op(std::move(const_op_impl));
  511. return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
  512. } else if (peer_op_type == DATA) {
  513. auto parent_node = NodeUtils::GetParentInput(peer_node);
  514. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  515. parent_node = NodeUtils::GetParentInput(parent_node);
  516. }
  517. if ((parent_node != nullptr) &&
  518. ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  519. auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node);
  520. GE_CHECK_NOTNULL(const_op_impl);
  521. Operator const_op(std::move(const_op_impl));
  522. return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
  523. }
  524. }
  525. // Try get from runtime inference context
  526. auto session_id = std::to_string(GetContext().SessionId());
  527. RuntimeInferenceContext *runtime_infer_ctx = nullptr;
  528. if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
  529. GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
  530. auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
  531. if (ret == GRAPH_SUCCESS) {
  532. return GRAPH_SUCCESS;
  533. }
  534. }
  535. } else {
  536. // For outer graph
  537. return GetInputConstDataOut(dst_name, data);
  538. }
  539. auto op_name = operator_impl_->GetName();
  540. GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str());
  541. return GRAPH_FAILED;
  542. }
  543. graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const {
  544. ge::OpIO out_handle("", 0, nullptr);
  545. GE_CHECK_NOTNULL(operator_impl_);
  546. if (operator_impl_->GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) {
  547. GELOGE(FAILED, "%s get input impl failed", dst_name.c_str());
  548. return GRAPH_FAILED;
  549. }
  550. if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) {
  551. Operator const_op(out_handle.GetOwner());
  552. const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType();
  553. if (op_desc_impl_type == CONSTANTOP) {
  554. return const_op.GetAttr(op::Constant::name_attr_value(), data);
  555. } else if (op_desc_impl_type == CONSTANT) {
  556. return const_op.GetAttr(op::Const::name_attr_value(), data);
  557. }
  558. }
  559. return GRAPH_FAILED;
  560. }
  561. std::shared_ptr<const Node> Operator::GetNode() const {
  562. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  563. return operator_impl_->GetNode();
  564. }
  565. TensorDesc Operator::GetInputDesc(const std::string &name) const {
  566. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  567. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name));
  568. }
  569. void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) {
  570. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  571. operator_impl_->SetInferenceContext(inference_context);
  572. }
  573. InferenceContextPtr Operator::GetInferenceContext() const {
  574. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  575. return operator_impl_->GetInferenceContext();
  576. }
  577. TensorDesc Operator::GetInputDesc(uint32_t index) const {
  578. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  579. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index));
  580. }
  581. graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const {
  582. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  583. auto check = operator_impl_->InputIsSet(name);
  584. if (check) tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name));
  585. return check ? GRAPH_SUCCESS : GRAPH_FAILED;
  586. }
  587. graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) {
  588. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  589. return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  590. }
  591. OutHandler Operator::GetOutput(const string &name) const {
  592. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  593. return operator_impl_->GetOutput(name);
  594. }
  595. OutHandler Operator::GetOutput(uint32_t index) const {
  596. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  597. return operator_impl_->GetOutput(index);
  598. }
  599. TensorDesc Operator::GetOutputDesc(const std::string &name) const {
  600. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  601. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name));
  602. }
  603. TensorDesc Operator::GetOutputDesc(uint32_t index) const {
  604. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  605. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index));
  606. }
  607. graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) {
  608. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  609. return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  610. }
  611. TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const {
  612. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  613. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index)));
  614. }
  615. graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) {
  616. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  617. return operator_impl_->UpdateInputDesc(name + std::to_string(index),
  618. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  619. }
  620. TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const {
  621. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  622. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index)));
  623. }
  624. graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) {
  625. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  626. return operator_impl_->UpdateOutputDesc(name + std::to_string(index),
  627. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  628. }
  629. graphStatus Operator::InferShapeAndType() {
  630. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  631. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  632. return operator_impl_->GetOpDescImpl()->CallInferFunc(*this);
  633. }
  634. graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) {
  635. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  636. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  637. if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) {
  638. return GRAPH_FAILED;
  639. } else {
  640. return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify();
  641. }
  642. }
  643. GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const {
  644. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr");
  645. return operator_impl_->GetInputsSize();
  646. }
  647. GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const {
  648. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr");
  649. return operator_impl_->GetOutputsSize();
  650. }
  651. // According to op get the attrs name and type
  652. namespace {
  653. const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = {
  654. {GeAttrValue::VT_NONE, "VT_STRING"},
  655. {GeAttrValue::VT_STRING, "VT_STRING"},
  656. {GeAttrValue::VT_FLOAT, "VT_FLOAT"},
  657. {GeAttrValue::VT_BOOL, "VT_BOOL"},
  658. {GeAttrValue::VT_INT, "VT_INT"},
  659. {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"},
  660. {GeAttrValue::VT_TENSOR, "VT_TENSOR"},
  661. {GeAttrValue::VT_BYTES, "VT_BYTES"},
  662. {GeAttrValue::VT_GRAPH, "VT_GRAPH"},
  663. {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"},
  664. {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"},
  665. {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"},
  666. {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"},
  667. {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"},
  668. {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"},
  669. {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"},
  670. {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"},
  671. {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"},
  672. {GeAttrValue::VT_GRAPH, "VT_GRAPH"},
  673. {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"},
  674. };
  675. } // namespace
  676. const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const {
  677. std::map<std::string, std::string> attr_types;
  678. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr.");
  679. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr.");
  680. std::map<string, GeAttrValue> attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  681. map<string, GeAttrValue>::iterator iter;
  682. for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) {
  683. string name = iter->first;
  684. GeAttrValue attr_value = iter->second;
  685. GeAttrValue::ValueType type = attr_value.GetValueType();
  686. auto iter2 = kAttrTypesMap.find(type);
  687. if (iter2 != kAttrTypesMap.end()) {
  688. attr_types[name] = iter2->second;
  689. }
  690. }
  691. return attr_types;
  692. }
  693. void Operator::InputRegister(const string &name) {
  694. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  695. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  696. (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc());
  697. }
  698. void Operator::OptionalInputRegister(const string &name) {
  699. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  700. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  701. // [No need to verify return value]
  702. (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name,
  703. GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED));
  704. }
  705. void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  706. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  707. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  708. // [No need to verify return value]
  709. (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func);
  710. }
  711. void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  712. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  713. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  714. // [No need to verify return value]
  715. (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func);
  716. }
  717. void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  718. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  719. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  720. // [No need to verify return value]
  721. (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func);
  722. }
  723. void Operator::OutputRegister(const string &name) {
  724. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  725. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  726. // [No need to verify return value]
  727. (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc());
  728. }
  729. void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) {
  730. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  731. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  732. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return,
  733. "set int failed");
  734. (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back);
  735. }
  736. void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) {
  737. GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr.");
  738. GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr.");
  739. operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index);
  740. }
  741. int Operator::GetDynamicInputNum(const string &name) const {
  742. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
  743. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  744. int num = 0;
  745. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num,
  746. "Get %s int failed", name.c_str());
  747. return num;
  748. }
  749. void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) {
  750. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  751. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  752. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return,
  753. "Set %s int failed", name.c_str());
  754. (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back);
  755. }
  756. int Operator::GetDynamicOutputNum(const string &name) const {
  757. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
  758. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  759. int num = 0;
  760. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num,
  761. "Get %s int failed", name.c_str());
  762. return num;
  763. }
  764. void Operator::RequiredAttrRegister(const string &name) {
  765. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  766. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  767. operator_impl_->GetOpDescImpl()->AddRequiredAttr(name);
  768. }
  769. graphStatus Operator::VerifyAll() {
  770. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  771. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  772. // Check all inputs defined
  773. for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) {
  774. GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname),
  775. GRAPH_FAILED, "operator input %s is not linked.", iname.c_str());
  776. vector<int64_t> ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims();
  777. for (int64_t dim : ishape) {
  778. GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
  779. iname.c_str());
  780. }
  781. }
  782. // Check all attributes defined
  783. const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  784. for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) {
  785. GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
  786. "operator attribute %s is empty.", name.c_str());
  787. }
  788. return GRAPH_SUCCESS;
  789. }
  790. string Operator::GetOpType() const {
  791. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr.");
  792. return OperatorImpl::GetOpDesc(*this)->GetType();
  793. }
  794. Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) {
  795. string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index);
  796. return SetInput(dynamic_dst_name, src_oprt);
  797. }
  798. Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt,
  799. const std::string &name) {
  800. string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index);
  801. return SetInput(dynamic_dst_name, src_oprt, name);
  802. }
  803. OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
  // Generates Operator::SetAttr(name, ArgType) on top of
  // AttrUtils::Set<AttrUtilsFun>. The generated setter always returns *this:
  // a missing implementation/op desc is reported with GELOGE and a failed
  // AttrUtils call with GELOGW.
  // (No comments inside the macro body: a `//` comment would swallow the
  // line continuation.)
  #define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \
    Operator &Operator::SetAttr(const string &name, ArgType attr_value) { \
      const bool has_op_desc = (operator_impl_ != nullptr) && (operator_impl_->GetOpDescImpl() != nullptr); \
      if (!has_op_desc) { \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \
        return *this; \
      } \
      const bool stored = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value); \
      if (!stored) { \
        GELOGW("set attr name %s failed.", name.c_str()); \
      } \
      return *this; \
    }  // lint !e665
  // Generates Operator::GetAttr(name, ArgType) on top of
  // AttrUtils::Get<AttrUtilsFun>. The generated getter returns GRAPH_SUCCESS
  // when the attribute was read, and GRAPH_FAILED when the implementation/op
  // desc is missing (GELOGE) or the AttrUtils call fails (GELOGW).
  // (No comments inside the macro body: a `//` comment would swallow the
  // line continuation.)
  #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \
    graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \
      const bool has_op_desc = (operator_impl_ != nullptr) && (operator_impl_->GetOpDescImpl() != nullptr); \
      if (!has_op_desc) { \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \
        return GRAPH_FAILED; \
      } \
      const bool fetched = AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value); \
      if (!fetched) { \
        GELOGW("get attr name %s failed.", name.c_str()); \
        return GRAPH_FAILED; \
      } \
      return GRAPH_SUCCESS; \
    }  // lint !e665
  827. void Operator::BreakConnect() const {
  828. if (operator_impl_ == nullptr) {
  829. GELOGW("operator impl is nullptr.");
  830. return;
  831. }
  832. operator_impl_->ClearInputLinks();
  833. operator_impl_->ClearOutputLinks();
  834. OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_);
  835. }
  // Generates Operator::AttrRegister(name, ArgType) on top of
  // AttrUtils::Set<AttrUtilsFun>. A missing implementation/op desc is
  // reported with GELOGE, a failed AttrUtils call with GELOGW.
  // (No comments inside the macro body: a `//` comment would swallow the
  // line continuation.)
  #define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \
    void Operator::AttrRegister(const string &name, ArgType attr_value) { \
      const bool has_op_desc = (operator_impl_ != nullptr) && (operator_impl_->GetOpDescImpl() != nullptr); \
      if (!has_op_desc) { \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \
        return; \
      } \
      const bool registered = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value); \
      if (!registered) { \
        GELOGW("reg attr name %s failed.", name.c_str()); \
      } \
    }  // lint !e665
  846. OP_ATTR_SET_IMP(int64_t, Int)
  847. OP_ATTR_SET_IMP(int32_t, Int)
  848. OP_ATTR_SET_IMP(uint32_t, Int)
  849. OP_ATTR_GET_IMP(int64_t &, Int)
  850. OP_ATTR_GET_IMP(int32_t &, Int)
  851. OP_ATTR_GET_IMP(uint32_t &, Int)
  852. OP_ATTR_SET_IMP(const vector<int64_t> &, ListInt)
  853. OP_ATTR_SET_IMP(const vector<int32_t> &, ListInt)
  854. OP_ATTR_SET_IMP(const vector<uint32_t> &, ListInt)
  855. OP_ATTR_SET_IMP(std::initializer_list<int64_t> &&, ListInt)
  856. OP_ATTR_GET_IMP(vector<int64_t> &, ListInt)
  857. OP_ATTR_GET_IMP(vector<int32_t> &, ListInt)
  858. OP_ATTR_GET_IMP(vector<uint32_t> &, ListInt)
  859. OP_ATTR_GET_IMP(vector<vector<int64_t>> &, ListListInt)
  860. OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt)
  861. OP_ATTR_SET_IMP(float, Float)
  862. OP_ATTR_GET_IMP(float &, Float)
  863. OP_ATTR_SET_IMP(const vector<float> &, ListFloat)
  864. OP_ATTR_GET_IMP(vector<float> &, ListFloat) // lint !e665
  865. OP_ATTR_SET_IMP(bool, Bool)
  866. OP_ATTR_GET_IMP(bool &, Bool)
  867. OP_ATTR_SET_IMP(const vector<bool> &, ListBool)
  868. OP_ATTR_GET_IMP(vector<bool> &, ListBool) // lint !e665
  869. OP_ATTR_SET_IMP(const string &, Str)
  870. OP_ATTR_GET_IMP(string &, Str)
  871. OP_ATTR_SET_IMP(const vector<string> &, ListStr)
  872. OP_ATTR_GET_IMP(vector<string> &, ListStr) // lint !e665
  873. OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  874. OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  875. OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  876. OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) // lint !e665
  877. OP_ATTR_REG_IMP(int64_t, Int)
  878. OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt)
  879. OP_ATTR_REG_IMP(float, Float)
  880. OP_ATTR_REG_IMP(const vector<float> &, ListFloat)
  881. OP_ATTR_REG_IMP(const string &, Str)
  882. OP_ATTR_REG_IMP(const vector<string> &, ListStr)
  883. OP_ATTR_REG_IMP(bool, Bool)
  884. OP_ATTR_REG_IMP(const vector<bool> &, ListBool)
  885. OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt)
  886. OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  887. OP_ATTR_REG_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  888. #undef OP_ATTR_SET_IMP
  889. #undef OP_ATTR_GET_IMP
  890. #undef OP_ATTR_REG_IMP
  891. Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) {
  892. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  893. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  894. return *this;
  895. }
  896. GeTensor tensor = TensorAdapter::AsGeTensor(attr_value);
  897. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  898. GELOGW("set attr name %s failed.", name.c_str());
  899. }
  900. return *this;
  901. }
  902. Operator &Operator::SetAttr(const string &name, const vector<Tensor> &attr_value) {
  903. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  904. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  905. return *this;
  906. }
  907. vector<GeTensor> val_list;
  908. for (const auto &item : attr_value) {
  909. auto tensor = TensorAdapter::AsGeTensor(item);
  910. val_list.push_back(tensor);
  911. }
  912. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  913. GELOGW("set attr name %s failed.", name.c_str());
  914. }
  915. return *this;
  916. }
  917. graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const {
  918. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  919. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  920. return GRAPH_FAILED;
  921. }
  922. ConstGeTensorPtr tensor;
  923. if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  924. GELOGW("get attr name %s failed.", name.c_str());
  925. return GRAPH_FAILED;
  926. }
  927. attr_value = TensorAdapter::GeTensor2Tensor(tensor);
  928. return GRAPH_SUCCESS;
  929. }
  930. graphStatus Operator::GetAttr(const string &name, vector<Tensor> &attr_value) const {
  931. attr_value.clear();
  932. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  933. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  934. return GRAPH_FAILED;
  935. }
  936. vector<ConstGeTensorPtr> val_list;
  937. if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  938. GELOGW("get attr name %s failed.", name.c_str());
  939. return GRAPH_FAILED;
  940. }
  941. for (auto &tensor : val_list) {
  942. attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor));
  943. }
  944. return GRAPH_SUCCESS;
  945. }
  946. Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) {
  947. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  948. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  949. return *this;
  950. }
  951. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name,
  952. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  953. GELOGW("set attr name %s failed.", name.c_str());
  954. }
  955. return *this;
  956. }
  957. graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const {
  958. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  959. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  960. return GRAPH_FAILED;
  961. }
  962. Buffer buffer;
  963. if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) {
  964. GELOGW("get attr name %s failed.", name.c_str());
  965. return GRAPH_FAILED;
  966. }
  967. attr_value.clear();
  968. if (buffer.data() == nullptr) {
  969. GELOGE(GRAPH_FAILED, "buffer data is null.");
  970. return GRAPH_FAILED;
  971. }
  972. attr_value.assign(buffer.data(), buffer.data() + buffer.size());
  973. return GRAPH_SUCCESS;
  974. }
  975. Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) {
  976. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  977. (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_));
  978. return *this;
  979. }
  980. graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const {
  981. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  982. return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_);
  983. }
  984. Operator &Operator::SetAttr(const string &name, const std::vector<ge::DataType> &attr_value) {
  985. if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) {
  986. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  987. return *this;
  988. }
  989. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  990. GELOGW("set attr name %s failed.", name.c_str());
  991. }
  992. return *this;
  993. }
  994. graphStatus Operator::GetAttr(const string &name, std::vector<ge::DataType> &attr_value) const {
  995. attr_value.clear();
  996. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  997. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  998. return GRAPH_FAILED;
  999. }
  1000. if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1001. GELOGW("get attr name %s failed.", name.c_str());
  1002. return GRAPH_FAILED;
  1003. }
  1004. return GRAPH_SUCCESS;
  1005. }
  1006. Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) {
  1007. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1008. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1009. return *this;
  1010. }
  1011. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1012. GELOGW("set attr name %s failed.", name.c_str());
  1013. }
  1014. return *this;
  1015. }
  1016. graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const {
  1017. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1018. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1019. return GRAPH_FAILED;
  1020. }
  1021. if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1022. GELOGW("get attr name %s failed.", name.c_str());
  1023. return GRAPH_FAILED;
  1024. }
  1025. return GRAPH_SUCCESS;
  1026. }
  1027. void Operator::AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value) {
  1028. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1029. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1030. return;
  1031. }
  1032. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1033. GELOGW("set attr name %s failed.", name.c_str());
  1034. }
  1035. }
  1036. void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) {
  1037. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1038. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1039. return;
  1040. }
  1041. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1042. GELOGW("set attr name %s failed.", name.c_str());
  1043. }
  1044. }
  1045. void Operator::AttrRegister(const string &name, const Tensor &attr_value) {
  1046. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1047. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1048. return;
  1049. }
  1050. auto tensor = TensorAdapter::AsGeTensor(attr_value);
  1051. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  1052. GELOGW("reg attr name %s failed.", name.c_str());
  1053. }
  1054. }
  1055. void Operator::AttrRegister(const string &name, const vector<Tensor> &attr_value) {
  1056. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1057. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1058. return;
  1059. }
  1060. vector<GeTensor> val_list;
  1061. for (const auto &item : attr_value) {
  1062. val_list.push_back(TensorAdapter::AsGeTensor(item));
  1063. }
  1064. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  1065. GELOGW("reg attr name %s failed.", name.c_str());
  1066. }
  1067. }
  1068. void Operator::AttrRegister(const string &name, const OpBytes &attr_value) {
  1069. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1070. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1071. return;
  1072. }
  1073. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name,
  1074. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  1075. GELOGW("reg attr name %s failed.", name.c_str());
  1076. }
  1077. }
  1078. void Operator::SubgraphRegister(const std::string &name, bool dynamic) {
  1079. if (operator_impl_ == nullptr) {
  1080. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1081. return;
  1082. }
  1083. operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic);
  1084. }
  1085. void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) {
  1086. if (operator_impl_ == nullptr) {
  1087. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1088. return;
  1089. }
  1090. operator_impl_->SubgraphCountRegister(name, count);
  1091. }
  1092. void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) {
  1093. if (operator_impl_ == nullptr) {
  1094. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str());
  1095. return;
  1096. }
  1097. operator_impl_->SetSubgraphBuilder(ir_name, index, builder);
  1098. }
  1099. std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); }
  1100. SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const {
  1101. if (operator_impl_ == nullptr) {
  1102. GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
  1103. return nullptr;
  1104. }
  1105. return operator_impl_->GetSubgraphBuilder(ir_name, index);
  1106. }
  1107. SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const {
  1108. return GetDynamicSubgraphBuilder(ir_name, 0);
  1109. }
  1110. Graph Operator::GetSubgraph(const string &name) const {
  1111. if (operator_impl_ == nullptr) {
  1112. GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str());
  1113. return Graph("");
  1114. }
  1115. auto op_desc = OpDescUtils::GetOpDescFromOperator(*this);
  1116. if (op_desc == nullptr) {
  1117. GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str());
  1118. return Graph("");
  1119. }
  1120. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  1121. auto iter = subgraph_names_to_index.find(name);
  1122. if (iter == subgraph_names_to_index.end()) {
  1123. GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str());
  1124. return Graph("");
  1125. }
  1126. auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second);
  1127. if (subgraph_instance_name.empty()) {
  1128. GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second);
  1129. return Graph("");
  1130. }
  1131. auto node = operator_impl_->GetNode();
  1132. if (node == nullptr) {
  1133. GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str());
  1134. return Graph("");
  1135. }
  1136. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  1137. if (root_graph == nullptr) {
  1138. GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str());
  1139. return Graph("");
  1140. }
  1141. auto subgraph = root_graph->GetSubgraph(subgraph_instance_name);
  1142. if (subgraph == nullptr) {
  1143. GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(),
  1144. iter->second, subgraph_instance_name.c_str());
  1145. return Graph("");
  1146. }
  1147. return GraphUtils::CreateGraphFromComputeGraph(subgraph);
  1148. }
  1149. Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const {
  1150. return GetSubgraph(name + std::to_string(index));
  1151. }
  1152. size_t Operator::GetSubgraphNamesCount() const {
  1153. if (operator_impl_ == nullptr) {
  1154. GE_LOGE("Failed to get subgraph names count, the operator impl is null");
  1155. return 0;
  1156. }
  1157. return operator_impl_->GetSubgraphNamesCount();
  1158. }
  1159. class GraphBuilderImpl {
  1160. public:
  1161. explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) {
  1162. if (graph_ == nullptr) {
  1163. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  1164. return;
  1165. }
  1166. }
  1167. ~GraphBuilderImpl() {}
  1168. ComputeGraphPtr BuildGraph(const std::vector<Operator> &inputs) {
  1169. std::vector<OperatorImplPtr> vec_inputs;
  1170. for (auto &it : inputs) {
  1171. auto src_op_impl = it.operator_impl_;
  1172. GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null.");
  1173. GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null.");
  1174. string type = src_op_impl->op_desc_->GetType();
  1175. auto node_op = ge::OperatorFactory::CreateOperator("node_op", type);
  1176. auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  1177. node_op.BreakConnect();
  1178. GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null.");
  1179. if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA ||
  1180. type == VARIABLE || type == INITDATA || type == GETNEXT) {
  1181. vec_inputs.push_back(it.operator_impl_);
  1182. } else {
  1183. GELOGW("Input operator should be Data, Variable operator or operator that has output but no input.");
  1184. }
  1185. }
  1186. GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr,
  1187. "User Input do not include operator such as "
  1188. "Data, Variable operator or operator that has output but no input.");
  1189. auto ret = WalkAllOperators(vec_inputs);
  1190. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed.");
  1191. ret = AddEdge();
  1192. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed.");
  1193. return graph_;
  1194. }
  1195. const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const { return all_nodes_info_; }
  1196. private:
  1197. graphStatus WalkAllOperators(const std::vector<OperatorImplPtr> &vec_ops) {
  1198. GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.")
  1199. std::queue<std::vector<OperatorImplPtr>> que;
  1200. que.push(vec_ops);
  1201. while (!que.empty()) {
  1202. auto vec_tem = que.front();
  1203. que.pop();
  1204. for (const auto &op_impl : vec_tem) {
  1205. GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.")
  1206. GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue,
  1207. "This node %s has created.", op_impl->GetName().c_str())
  1208. auto node_ptr = graph_->AddNode(op_impl->op_desc_);
  1209. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed.");
  1210. all_nodes_info_.insert(std::make_pair(op_impl, node_ptr));
  1211. auto &out_links = op_impl->output_links_;
  1212. std::vector<OperatorImplPtr> vec_op_forward{};
  1213. for (const auto &out_link : out_links) {
  1214. for (const auto &op_forward : out_link.second) {
  1215. vec_op_forward.push_back(op_forward.GetOwner());
  1216. }
  1217. }
  1218. auto &out_control_links = op_impl->control_output_link_;
  1219. for (const auto &out_link : out_control_links) {
  1220. vec_op_forward.push_back(out_link.lock());
  1221. }
  1222. que.push(vec_op_forward);
  1223. auto &in_links = op_impl->input_link_;
  1224. std::vector<OperatorImplPtr> vec_op_back_forward{};
  1225. for (const auto &in_link : in_links) {
  1226. vec_op_back_forward.push_back(in_link.second.GetOwner());
  1227. }
  1228. auto &in_control_links = op_impl->control_input_link_;
  1229. for (const auto &in_link : in_control_links) {
  1230. vec_op_back_forward.push_back(in_link.lock());
  1231. }
  1232. que.push(vec_op_back_forward);
  1233. if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) {
  1234. return GRAPH_FAILED;
  1235. }
  1236. }
  1237. }
  1238. return MoveSubgraphToRoot(graph_);
  1239. }
  1240. graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) {
  1241. const string name = node->GetName();
  1242. for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) {
  1243. const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first);
  1244. GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str());
  1245. Graph graph = builder(); // Build subgraph from user define builder.
  1246. const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph);
  1247. GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str());
  1248. subgraph->SetParentNode(node);
  1249. subgraph->SetParentGraph(graph_);
  1250. if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
  1251. return GRAPH_FAILED;
  1252. }
  1253. if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) {
  1254. GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second);
  1255. return GRAPH_FAILED;
  1256. }
  1257. }
  1258. return GRAPH_SUCCESS;
  1259. }
  1260. graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) {
  1261. const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph);
  1262. if (root_graph == nullptr) {
  1263. GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str());
  1264. return GRAPH_FAILED;
  1265. }
  1266. if (root_graph == graph) {
  1267. auto subgraphs = graph->GetAllSubgraphs();
  1268. for (auto &subgraph : subgraphs) {
  1269. if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) {
  1270. return GRAPH_FAILED;
  1271. }
  1272. }
  1273. } else {
  1274. auto subgraphs = graph->GetAllSubgraphs();
  1275. for (auto &subgraph : subgraphs) {
  1276. if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
  1277. return GRAPH_FAILED;
  1278. }
  1279. graph->RemoveSubgraph(subgraph->GetName());
  1280. if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) {
  1281. return GRAPH_FAILED;
  1282. }
  1283. }
  1284. }
  1285. return GRAPH_SUCCESS;
  1286. }
  1287. graphStatus AddEdge() {
  1288. for (const auto &node_info : all_nodes_info_) {
  1289. auto src_op_impl_ptr = node_info.first;
  1290. auto src_node_ptr = node_info.second;
  1291. GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue);
  1292. auto out_links = src_op_impl_ptr->output_links_;
  1293. GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED,
  1294. "Src operator impl's op_desc is null.");
  1295. auto &op_desc = src_op_impl_ptr->op_desc_;
  1296. GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
  1297. for (const auto &out : out_links) {
  1298. auto src_idx = op_desc->GetOutputIndexByName(out.first);
  1299. GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed");
  1300. auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx);
  1301. GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed.");
  1302. for (const auto &dst_opio : out.second) {
  1303. auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner());
  1304. GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed.");
  1305. GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);
  1306. auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex());
  1307. GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed.");
  1308. auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor);
  1309. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED,
  1310. "from node[%s][%d] to node[%s][%d]AddEdge failed.", src_node_ptr->GetName().c_str(),
  1311. src_anchor->GetIdx(), dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx());
  1312. }
  1313. }
  1314. auto out_control_anchor = src_node_ptr->GetOutControlAnchor();
  1315. for (const auto &control_out : src_op_impl_ptr->control_output_link_) {
  1316. auto dst_node_info = all_nodes_info_.find(control_out.lock());
  1317. if (dst_node_info == all_nodes_info_.end()) {
  1318. GELOGE(GRAPH_FAILED, "Find Dst node failed.");
  1319. return GRAPH_FAILED;
  1320. }
  1321. GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);
  1322. auto in_control_anchor = dst_node_info->second->GetInControlAnchor();
  1323. auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor);
  1324. if (ret != GRAPH_SUCCESS) {
  1325. GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(),
  1326. op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(),
  1327. dst_node_info->second->GetType().c_str());
  1328. return ret;
  1329. }
  1330. }
  1331. }
  1332. return GRAPH_SUCCESS;
  1333. }
  1334. ComputeGraphPtr graph_ = nullptr;
  1335. std::map<OperatorImplPtr, NodePtr> all_nodes_info_{};
  1336. };
  1337. inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) {
  1338. for (const auto &graph : compute_graph->GetAllSubgraphs()) {
  1339. std::set<string> node_names;
  1340. for (auto const &node : graph->GetDirectNode()) {
  1341. auto result = node_names.insert(node->GetName());
  1342. if (!result.second) {
  1343. GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str());
  1344. return true;
  1345. }
  1346. }
  1347. }
  1348. std::set<string> node_names;
  1349. for (auto const &node : compute_graph->GetDirectNode()) {
  1350. auto result = node_names.insert(node->GetName());
  1351. if (!result.second) {
  1352. GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str());
  1353. return true;
  1354. }
  1355. }
  1356. return false;
  1357. }
  1358. ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) {
  1359. auto graph_builder_impl = GraphBuilderImpl(name);
  1360. ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs);
  1361. GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr");
  1362. compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo());
  1363. if (HasSameNameNode(compute_graph)) {
  1364. GELOGW("Compute do not allow has same name nodes.");
  1365. compute_graph = nullptr;
  1366. }
  1367. return compute_graph;
  1368. }
  1369. void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos) {
  1370. for (const auto &it : all_nodes_infos) {
  1371. OperatorImplPtr op_impl = it.first;
  1372. if (op_impl == nullptr) {
  1373. GELOGW("operator impl is nullptr.");
  1374. continue;
  1375. }
  1376. op_impl->ClearOutputLinks();
  1377. op_impl->ClearInputLinks();
  1378. OperatorKeeper::GetInstance().CheckOutOperator(op_impl);
  1379. }
  1380. }
  1381. } // namespace ge
  1382. /*lint +e446 +e732*/
  1383. /*lint +e665*/

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: