You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

operator.cc 65 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "external/graph/operator.h"
#include "external/graph/operator_factory.h"

#include <stdint.h>

#include <algorithm>
#include <atomic>
#include <mutex>
#include <queue>
#include <set>

#include "./array_ops.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "external/graph/attr_value.h"
#include "external/graph/types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_context.h"
#include "graph/ge_tensor.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/runtime_inference_context.h"
#include "graph/usr_types.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "utils/graph_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_adapter.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"
  50. using std::enable_shared_from_this;
  51. using std::make_pair;
  52. using std::shared_ptr;
  53. using std::string;
  54. using std::to_string;
  55. using std::vector;
  56. namespace ge {
  57. class OpIO {
  58. public:
  59. OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {}
  60. ~OpIO() = default;
  61. string GetName() const { return name_; }
  62. int GetIndex() const { return index_; }
  63. OperatorImplPtr GetOwner() const { return owner_; }
  64. bool operator==(const OpIO &r_value) const {
  65. return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) &&
  66. (this->GetOwner() == r_value.GetOwner());
  67. }
  68. private:
  69. string name_;
  70. int index_;
  71. std::shared_ptr<OperatorImpl> owner_;
  72. };
  73. class TensorTypeImpl {
  74. public:
  75. TensorTypeImpl() = default;
  76. ~TensorTypeImpl() = default;
  77. std::vector<DataType> dt_vec_;
  78. };
  79. TensorType::TensorType(DataType dt) {
  80. tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>();
  81. if (tensor_type_impl_ != nullptr) {
  82. tensor_type_impl_->dt_vec_.push_back(dt);
  83. }
  84. }
  85. TensorType::TensorType(const std::initializer_list<DataType> &types) {
  86. tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>();
  87. if (tensor_type_impl_ != nullptr) {
  88. tensor_type_impl_->dt_vec_ = types;
  89. }
  90. }
  91. class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {
  92. friend class GraphBuilderImpl;
  93. friend class OpDescUtils;
  94. public:
  95. explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared<OpDesc>(name, type)) {
  96. if (op_desc_ == nullptr) {
  97. GELOGW("OpDesc make shared failed");
  98. }
  99. }
  100. explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {}
  101. explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) {
  102. if (node_ != nullptr && node_->GetOpDesc() != nullptr) {
  103. op_desc_ = node_->GetOpDesc();
  104. }
  105. }
  106. ~OperatorImpl() {}
  107. void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) {
  108. GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty");
  109. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  110. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr.");
  111. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  112. auto src_op_impl = src_oprt.GetOperatorImplPtr();
  113. GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null.");
  114. GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null.");
  115. GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return,
  116. "The source operator[%s] must has one output",
  117. src_oprt.operator_impl_->op_desc_->GetName().c_str())
  118. uint32_t src_index = 0;
  119. string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index);
  120. GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty.");
  121. OpIO out_handler(src_name, src_index, src_op_impl);
  122. input_link_.insert(std::make_pair(dst_name, out_handler));
  123. int dst_index = op_desc_->GetInputIndexByName(dst_name);
  124. GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
  125. op_desc_->GetName().c_str());
  126. bool is_const = false;
  127. if (src_oprt.GetOpType() == CONSTANT) {
  128. is_const = true;
  129. }
  130. auto is_input_const = op_desc_->GetIsInputConst();
  131. for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
  132. is_input_const.push_back(false);
  133. }
  134. is_input_const[dst_index] = is_const;
  135. op_desc_->SetIsInputConst(is_input_const);
  136. OpIO op_dst(dst_name, dst_index, shared_from_this());
  137. src_op_impl->UpdateLinkMapImpl(src_name, op_dst);
  138. auto output_desc = src_op_impl->GetOutputDesc(src_name);
  139. auto input_desc = op_desc_->GetInputDesc(dst_name);
  140. if (input_desc.GetFormat() == FORMAT_RESERVED) {
  141. output_desc.SetFormat(FORMAT_ND);
  142. } else {
  143. output_desc.SetFormat(input_desc.GetFormat());
  144. }
  145. // Fix for linking opdesc
  146. if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) {
  147. GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(),
  148. src_name.c_str());
  149. return;
  150. }
  151. }
  152. void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) {
  153. GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty");
  154. GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr.");
  155. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr.");
  156. input_link_.insert(std::make_pair(dst_name, *out_handler));
  157. string src_name = out_handler->GetName();
  158. int dst_index = op_desc_->GetInputIndexByName(dst_name);
  159. GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
  160. op_desc_->GetName().c_str());
  161. auto out_op_impl = out_handler->GetOwner();
  162. GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return,
  163. "out_handler invalid. name[%s]", dst_name.c_str());
  164. bool is_const = false;
  165. if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) {
  166. is_const = true;
  167. }
  168. auto is_input_const = op_desc_->GetIsInputConst();
  169. for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
  170. is_input_const.push_back(false);
  171. }
  172. is_input_const[dst_index] = is_const;
  173. op_desc_->SetIsInputConst(is_input_const);
  174. OpIO in_handler(dst_name, dst_index, shared_from_this());
  175. GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed.");
  176. out_op_impl->UpdateLinkMapImpl(src_name, in_handler);
  177. auto src_output_desc = out_op_impl->GetOutputDesc(src_name);
  178. auto dst_input_desc = op_desc_->GetInputDesc(dst_name);
  179. if (dst_input_desc.GetFormat() == FORMAT_RESERVED) {
  180. src_output_desc.SetFormat(FORMAT_ND);
  181. } else {
  182. src_output_desc.SetFormat(dst_input_desc.GetFormat());
  183. }
  184. GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return,
  185. "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(),
  186. src_name.c_str()); // fix for linking opdesc
  187. }
  188. void AddControlInputImp(const ge::Operator &src_oprt) {
  189. if (src_oprt.operator_impl_ == nullptr) {
  190. GELOGE(FAILED, "Src operator impl is nullptr");
  191. return;
  192. }
  193. for (auto &input : control_input_link_) {
  194. if (input.lock() == src_oprt.operator_impl_) {
  195. return;
  196. }
  197. }
  198. control_input_link_.push_back(src_oprt.operator_impl_);
  199. src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this());
  200. }
  201. graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) {
  202. auto out = input_link_.find(dst_name);
  203. if (out == input_link_.end()) {
  204. return GRAPH_FAILED;
  205. }
  206. out_handler = out->second;
  207. return GRAPH_SUCCESS;
  208. }
  209. bool InputIsSet(const string &name) {
  210. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr.");
  211. return op_desc_->InputIsSet(name);
  212. }
  213. string GetName() const {
  214. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr.");
  215. return op_desc_->GetName();
  216. }
  217. GeTensorDesc GetInputDesc(const string &name) const {
  218. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  219. return op_desc_->GetInputDesc(name);
  220. }
  221. GeTensorDesc GetInputDesc(uint32_t index) const {
  222. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  223. return op_desc_->GetInputDesc(index);
  224. }
  225. graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) {
  226. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr.");
  227. return op_desc_->UpdateInputDesc(name, tensor_desc);
  228. }
  229. OutHandler GetOutput(const string &name) {
  230. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr.");
  231. int src_index = op_desc_->GetOutputIndexByName(name);
  232. GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str());
  233. shared_ptr<OpIO> output_ptr = ComGraphMakeShared<OpIO>(name, src_index, shared_from_this());
  234. if (output_ptr == nullptr) {
  235. GELOGE(GRAPH_FAILED, "OpIO make shared failed");
  236. return nullptr;
  237. }
  238. return output_ptr;
  239. }
  240. OutHandler GetOutput(uint32_t index) {
  241. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr.");
  242. string name = op_desc_->GetOutputNameByIndex(index);
  243. if (name.empty()) {
  244. GELOGE(GRAPH_FAILED, "Find src name by index failed. index[%u]", index);
  245. return nullptr;
  246. }
  247. shared_ptr<OpIO> output_ptr = ComGraphMakeShared<OpIO>(name, index, shared_from_this());
  248. if (output_ptr == nullptr) {
  249. GELOGE(GRAPH_FAILED, "OpIO make shared failed");
  250. return nullptr;
  251. }
  252. return output_ptr;
  253. }
  254. GeTensorDesc GetOutputDesc(const string &name) const {
  255. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  256. return op_desc_->GetOutputDesc(name);
  257. }
  258. GeTensorDesc GetOutputDesc(uint32_t index) const {
  259. GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr.");
  260. return op_desc_->GetOutputDesc(index);
  261. }
  262. graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) {
  263. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  264. auto res = op_desc_->UpdateOutputDesc(name, tensor_desc);
  265. if (res == GRAPH_SUCCESS) {
  266. for (auto ol : output_links_[name]) {
  267. if (ol.GetOwner() == nullptr) {
  268. GELOGW("%s get owner is nullptr", ol.GetName().c_str());
  269. continue;
  270. }
  271. GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED,
  272. "Could not update next operator's input %s.", ol.GetName().c_str());
  273. }
  274. }
  275. return res;
  276. }
  277. size_t GetInputsSize() const {
  278. GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0);
  279. return op_desc_->GetInputsSize();
  280. }
  281. size_t GetOutputsSize() const {
  282. GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0);
  283. return op_desc_->GetOutputsSize();
  284. }
  285. graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) {
  286. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  287. return op_desc_->SetAttr(name, std::move(attr_value));
  288. }
  289. graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const {
  290. GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr.");
  291. return op_desc_->GetAttr(name, attr_value);
  292. }
  293. OpDescPtr GetOpDescImpl() const { return op_desc_; }
  294. void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) {
  295. auto it_find = output_links_.find(src_name);
  296. if (it_find == output_links_.end()) {
  297. std::vector<OpIO> dsts{op_dst};
  298. output_links_.insert(std::make_pair(src_name, dsts));
  299. } else {
  300. it_find->second.push_back(op_dst);
  301. }
  302. }
  303. Operator ToOperator() { return Operator(shared_from_this()); }
  304. static OpDescPtr GetOpDesc(const Operator &oprt) {
  305. GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr);
  306. return oprt.operator_impl_->op_desc_;
  307. }
  308. void ClearOutputLinks() noexcept { output_links_.clear(); }
  309. void ClearInputLinks() noexcept { input_link_.clear(); }
  310. ge::ConstNodePtr GetNode() { return node_; }
  311. void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; }
  312. InferenceContextPtr GetInferenceContext() const { return inference_context_; }
  313. void SubgraphRegister(const std::string &ir_name, bool dynamic) {
  314. op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic);
  315. }
  316. void SubgraphCountRegister(const std::string &ir_name, uint32_t count) {
  317. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) {
  318. op_desc_->AddSubgraphName(ir_name);
  319. subgraph_names_to_builders_[ir_name] = nullptr;
  320. } else {
  321. for (uint32_t i = 0; i < count; ++i) {
  322. string key_name = ir_name + std::to_string(i);
  323. op_desc_->AddSubgraphName(key_name);
  324. subgraph_names_to_builders_[key_name] = nullptr;
  325. }
  326. }
  327. }
  328. void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) {
  329. string key_name = ir_name;
  330. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) {
  331. key_name += std::to_string(index);
  332. }
  333. auto it = subgraph_names_to_builders_.find(key_name);
  334. if (it == subgraph_names_to_builders_.end()) {
  335. GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index);
  336. return;
  337. }
  338. it->second = builder;
  339. }
  340. SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const {
  341. string key_name = ir_name;
  342. if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) {
  343. key_name += std::to_string(index);
  344. }
  345. return GetSubgraphBuilder(key_name);
  346. }
  347. SubgraphBuilder GetSubgraphBuilder(const std::string &name) const {
  348. auto iter = subgraph_names_to_builders_.find(name);
  349. if (iter == subgraph_names_to_builders_.end()) {
  350. GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str());
  351. return nullptr;
  352. }
  353. return iter->second;
  354. }
  355. std::vector<std::string> GetSubgraphNames() const {
  356. std::vector<std::string> names;
  357. for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) {
  358. names.emplace_back(subgraph_name_to_type.first);
  359. }
  360. return names;
  361. }
  362. size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); }
  363. OpDescPtr op_desc_ = nullptr;
  364. private:
  365. ge::ConstNodePtr node_{nullptr};
  366. ge::InferenceContextPtr inference_context_;
  367. std::map<string, std::vector<OpIO>> output_links_{};
  368. std::map<string, OpIO> input_link_{};
  369. std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{};
  370. std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{};
  371. std::map<std::string, SubgraphBuilder> subgraph_names_to_builders_;
  372. };
  373. // Used to manage OperatorImpl instances created by ge api.
  374. class OperatorKeeper {
  375. private:
  376. OperatorKeeper() = default;
  377. ~OperatorKeeper() {
  378. for (const auto &iter : operators_) {
  379. if (iter) {
  380. iter->ClearInputLinks();
  381. iter->ClearOutputLinks();
  382. }
  383. }
  384. }
  385. std::set<OperatorImplPtr> operators_;
  386. std::mutex mutex_;
  387. public:
  388. static OperatorKeeper &GetInstance() {
  389. static OperatorKeeper instance;
  390. return instance;
  391. }
  392. void CheckInOperator(const OperatorImplPtr &op_impl) {
  393. if (op_impl) {
  394. std::lock_guard<std::mutex> lock(mutex_);
  395. operators_.insert(op_impl);
  396. }
  397. }
  398. void CheckOutOperator(const OperatorImplPtr &op_impl) {
  399. if (op_impl) {
  400. std::lock_guard<std::mutex> lock(mutex_);
  401. operators_.erase(op_impl);
  402. }
  403. }
  404. };
  405. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) {
  406. ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(node_ptr);
  407. if (operator_impl_ptr == nullptr) {
  408. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  409. return Operator("default");
  410. }
  411. return operator_impl_ptr->ToOperator();
  412. }
  413. Operator::Operator(const std::string &type) {
  414. static uint32_t index = 0;
  415. string name = type + "_" + std::to_string(index++);
  416. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, type);
  417. if (operator_impl_ == nullptr) {
  418. GELOGW("OperatorImpl make shared failed");
  419. }
  420. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  421. }
  422. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) {
  423. shared_ptr<OperatorImpl> operator_impl_ptr;
  424. operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(op_desc);
  425. if (operator_impl_ptr == nullptr) {
  426. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  427. return Operator("default");
  428. }
  429. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr);
  430. return operator_impl_ptr->ToOperator();
  431. }
  432. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) {
  433. return OperatorImpl::GetOpDesc(oprt);
  434. }
  435. GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) {
  436. operator_impl_ = ComGraphMakeShared<OperatorImpl>(name, type);
  437. if (operator_impl_ == nullptr) {
  438. GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
  439. return;
  440. }
  441. OperatorKeeper::GetInstance().CheckInOperator(operator_impl_);
  442. }
  443. Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); }
  444. bool Operator::IsEmpty() const {
  445. if (operator_impl_ == nullptr) {
  446. return true;
  447. }
  448. return false;
  449. }
  450. string Operator::GetName() const {
  451. if (operator_impl_ != nullptr) {
  452. return operator_impl_->GetName();
  453. }
  454. return "";
  455. }
  456. GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) {
  457. // Describe the connection relationship between operators, no create action
  458. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  459. operator_impl_->SetInputImpl(dst_name, src_oprt);
  460. return *this;
  461. }
  462. Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) {
  463. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  464. operator_impl_->SetInputImpl(dst_name, out_handler);
  465. return *this;
  466. }
  467. Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) {
  468. auto out_handler = src_oprt.GetOutput(name);
  469. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  470. (void)SetInput(dst_name, out_handler);
  471. return *this;
  472. }
  473. Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) {
  474. auto out_handler = src_oprt.GetOutput(index);
  475. GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr.");
  476. (void)SetInput(dst_name, out_handler);
  477. return *this;
  478. }
  479. Operator &Operator::AddControlInput(const Operator &src_oprt) {
  480. if (operator_impl_ == nullptr) {
  481. GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
  482. return *this;
  483. }
  484. operator_impl_->AddControlInputImp(src_oprt);
  485. return *this;
  486. }
  487. graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const {
  488. GE_CHECK_NOTNULL(operator_impl_);
  489. auto node_ptr = operator_impl_->GetNode();
  490. if (node_ptr != nullptr) {
  491. // For inner compute graph
  492. auto op_desc = node_ptr->GetOpDesc();
  493. GE_CHECK_NOTNULL(op_desc);
  494. auto index = op_desc->GetInputIndexByName(dst_name);
  495. auto in_data_anchor = node_ptr->GetInDataAnchor(index);
  496. GE_CHECK_NOTNULL(in_data_anchor);
  497. auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  498. GE_CHECK_NOTNULL(out_data_anchor);
  499. auto peer_node = out_data_anchor->GetOwnerNode();
  500. GE_CHECK_NOTNULL(peer_node);
  501. auto peer_op_desc = peer_node->GetOpDesc();
  502. GE_CHECK_NOTNULL(peer_op_desc);
  503. auto peer_op_type = peer_op_desc->GetType();
  504. if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
  505. auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node);
  506. GE_CHECK_NOTNULL(const_op_impl);
  507. Operator const_op(std::move(const_op_impl));
  508. return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
  509. } else if (peer_op_type == DATA) {
  510. auto parent_node = NodeUtils::GetParentInput(peer_node);
  511. while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
  512. parent_node = NodeUtils::GetParentInput(parent_node);
  513. }
  514. if ((parent_node != nullptr) &&
  515. ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
  516. auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node);
  517. GE_CHECK_NOTNULL(const_op_impl);
  518. Operator const_op(std::move(const_op_impl));
  519. return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
  520. }
  521. }
  522. // Try get from runtime inference context
  523. auto session_id = std::to_string(GetContext().SessionId());
  524. RuntimeInferenceContext *runtime_infer_ctx = nullptr;
  525. if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
  526. GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
  527. auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
  528. if (ret == GRAPH_SUCCESS) {
  529. return GRAPH_SUCCESS;
  530. }
  531. }
  532. } else {
  533. // For outer graph
  534. return GetInputConstDataOut(dst_name, data);
  535. }
  536. auto op_name = operator_impl_->GetName();
  537. GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str());
  538. return GRAPH_FAILED;
  539. }
  540. graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const {
  541. ge::OpIO out_handle("", 0, nullptr);
  542. GE_CHECK_NOTNULL(operator_impl_);
  543. if (operator_impl_->GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) {
  544. GELOGE(FAILED, "%s get input impl failed", dst_name.c_str());
  545. return GRAPH_FAILED;
  546. }
  547. if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) {
  548. Operator const_op(out_handle.GetOwner());
  549. const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType();
  550. if (op_desc_impl_type == CONSTANTOP) {
  551. return const_op.GetAttr(op::Constant::name_attr_value(), data);
  552. } else if (op_desc_impl_type == CONSTANT) {
  553. return const_op.GetAttr(op::Const::name_attr_value(), data);
  554. }
  555. }
  556. return GRAPH_FAILED;
  557. }
  558. std::shared_ptr<const Node> Operator::GetNode() const {
  559. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  560. return operator_impl_->GetNode();
  561. }
  562. TensorDesc Operator::GetInputDesc(const std::string &name) const {
  563. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  564. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name));
  565. }
  566. void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) {
  567. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  568. operator_impl_->SetInferenceContext(inference_context);
  569. }
  570. InferenceContextPtr Operator::GetInferenceContext() const {
  571. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  572. return operator_impl_->GetInferenceContext();
  573. }
  574. TensorDesc Operator::GetInputDesc(uint32_t index) const {
  575. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  576. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index));
  577. }
  578. graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const {
  579. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  580. auto check = operator_impl_->InputIsSet(name);
  581. if (check) tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name));
  582. return check ? GRAPH_SUCCESS : GRAPH_FAILED;
  583. }
  584. graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) {
  585. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  586. return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  587. }
  588. OutHandler Operator::GetOutput(const string &name) const {
  589. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  590. return operator_impl_->GetOutput(name);
  591. }
  592. OutHandler Operator::GetOutput(uint32_t index) const {
  593. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr.");
  594. return operator_impl_->GetOutput(index);
  595. }
  596. TensorDesc Operator::GetOutputDesc(const std::string &name) const {
  597. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  598. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name));
  599. }
  600. TensorDesc Operator::GetOutputDesc(uint32_t index) const {
  601. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  602. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index));
  603. }
  604. graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) {
  605. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  606. return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  607. }
  608. TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const {
  609. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  610. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index)));
  611. }
  612. graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) {
  613. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  614. return operator_impl_->UpdateInputDesc(name + std::to_string(index),
  615. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  616. }
  617. TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const {
  618. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr.");
  619. return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index)));
  620. }
  621. graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) {
  622. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  623. return operator_impl_->UpdateOutputDesc(name + std::to_string(index),
  624. TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
  625. }
  626. graphStatus Operator::InferShapeAndType() {
  627. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  628. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  629. return operator_impl_->GetOpDescImpl()->CallInferFunc(*this);
  630. }
  631. graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) {
  632. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  633. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  634. if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) {
  635. return GRAPH_FAILED;
  636. } else {
  637. return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify();
  638. }
  639. }
  640. GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const {
  641. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr");
  642. return operator_impl_->GetInputsSize();
  643. }
  644. GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const {
  645. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr");
  646. return operator_impl_->GetOutputsSize();
  647. }
  648. // According to op get the attrs name and type
  649. namespace {
  650. const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = {
  651. {GeAttrValue::VT_NONE, "VT_STRING"},
  652. {GeAttrValue::VT_STRING, "VT_STRING"},
  653. {GeAttrValue::VT_FLOAT, "VT_FLOAT"},
  654. {GeAttrValue::VT_BOOL, "VT_BOOL"},
  655. {GeAttrValue::VT_INT, "VT_INT"},
  656. {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"},
  657. {GeAttrValue::VT_TENSOR, "VT_TENSOR"},
  658. {GeAttrValue::VT_BYTES, "VT_BYTES"},
  659. {GeAttrValue::VT_GRAPH, "VT_GRAPH"},
  660. {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"},
  661. {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"},
  662. {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"},
  663. {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"},
  664. {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"},
  665. {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"},
  666. {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"},
  667. {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"},
  668. {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"},
  669. {GeAttrValue::VT_GRAPH, "VT_GRAPH"},
  670. {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"},
  671. };
  672. } // namespace
  673. const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const {
  674. std::map<std::string, std::string> attr_types;
  675. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr.");
  676. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr.");
  677. std::map<string, GeAttrValue> attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  678. map<string, GeAttrValue>::iterator iter;
  679. for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) {
  680. string name = iter->first;
  681. GeAttrValue attr_value = iter->second;
  682. GeAttrValue::ValueType type = attr_value.GetValueType();
  683. auto iter2 = kAttrTypesMap.find(type);
  684. if (iter2 != kAttrTypesMap.end()) {
  685. attr_types[name] = iter2->second;
  686. }
  687. }
  688. return attr_types;
  689. }
  690. void Operator::InputRegister(const string &name) {
  691. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  692. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  693. (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc());
  694. }
  695. void Operator::OptionalInputRegister(const string &name) {
  696. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  697. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  698. // [No need to verify return value]
  699. (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name,
  700. GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED));
  701. }
  702. void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  703. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  704. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  705. // [No need to verify return value]
  706. (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func);
  707. }
  708. void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  709. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  710. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  711. // [No need to verify return value]
  712. (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func);
  713. }
  714. void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  715. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  716. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  717. // [No need to verify return value]
  718. (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func);
  719. }
  720. void Operator::OutputRegister(const string &name) {
  721. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  722. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  723. // [No need to verify return value]
  724. (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc());
  725. }
  726. void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) {
  727. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  728. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  729. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return,
  730. "set int failed");
  731. (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back);
  732. }
  733. void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) {
  734. GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr.");
  735. GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr.");
  736. operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index);
  737. }
  738. int Operator::GetDynamicInputNum(const string &name) const {
  739. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
  740. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  741. int num = 0;
  742. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num,
  743. "Get %s int failed", name.c_str());
  744. return num;
  745. }
  746. void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) {
  747. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  748. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  749. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return,
  750. "Set %s int failed", name.c_str());
  751. (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back);
  752. }
  753. int Operator::GetDynamicOutputNum(const string &name) const {
  754. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
  755. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
  756. int num = 0;
  757. GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num,
  758. "Get %s int failed", name.c_str());
  759. return num;
  760. }
  761. void Operator::RequiredAttrRegister(const string &name) {
  762. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  763. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
  764. operator_impl_->GetOpDescImpl()->AddRequiredAttr(name);
  765. }
  766. graphStatus Operator::VerifyAll() {
  767. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  768. GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr.");
  769. // Check all inputs defined
  770. for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) {
  771. GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname),
  772. GRAPH_FAILED, "operator input %s is not linked.", iname.c_str());
  773. vector<int64_t> ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims();
  774. for (int64_t dim : ishape) {
  775. GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
  776. iname.c_str());
  777. }
  778. }
  779. // Check all attributes defined
  780. const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs();
  781. for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) {
  782. GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
  783. "operator attribute %s is empty.", name.c_str());
  784. }
  785. return GRAPH_SUCCESS;
  786. }
  787. string Operator::GetOpType() const {
  788. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr.");
  789. return OperatorImpl::GetOpDesc(*this)->GetType();
  790. }
  791. Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) {
  792. string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index);
  793. return SetInput(dynamic_dst_name, src_oprt);
  794. }
  795. Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt,
  796. const std::string &name) {
  797. string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index);
  798. return SetInput(dynamic_dst_name, src_oprt, name);
  799. }
  800. OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
  // Generates Operator::SetAttr(name, ArgType): stores `attr_value` on the op
  // desc through AttrUtils::Set<AttrUtilsFun>. A missing implementation object
  // or op desc is logged as an error; a failed store is logged as a warning.
  // The generated function always returns *this.
  #define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun)                                                                \
    Operator &Operator::SetAttr(const string &name, ArgType attr_value) {                                       \
      const bool has_op_desc = (operator_impl_ != nullptr) && (operator_impl_->GetOpDescImpl() != nullptr);     \
      if (!has_op_desc) {                                                                                       \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());                               \
        return *this;                                                                                           \
      }                                                                                                         \
      const bool stored = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value);      \
      if (!stored) {                                                                                            \
        GELOGW("set attr name %s failed.", name.c_str());                                                       \
      }                                                                                                         \
      return *this;                                                                                             \
    }
  // Generates Operator::GetAttr(name, ArgType): reads the attribute through
  // AttrUtils::Get<AttrUtilsFun> into `attr_value`.
  // The generated function returns GRAPH_FAILED when the implementation object
  // or op desc is missing (error log) or when the read fails (warning log),
  // and GRAPH_SUCCESS otherwise.
  #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun)                                                                \
    graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const {                               \
      const bool has_op_desc = (operator_impl_ != nullptr) && (operator_impl_->GetOpDescImpl() != nullptr);     \
      if (!has_op_desc) {                                                                                       \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());                               \
        return GRAPH_FAILED;                                                                                    \
      }                                                                                                         \
      const bool fetched = AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value);     \
      if (fetched) {                                                                                            \
        return GRAPH_SUCCESS;                                                                                   \
      }                                                                                                         \
      GELOGW("get attr name %s failed.", name.c_str());                                                         \
      return GRAPH_FAILED;                                                                                      \
    }
  824. void Operator::BreakConnect() const {
  825. if (operator_impl_ == nullptr) {
  826. GELOGW("operator impl is nullptr.");
  827. return;
  828. }
  829. operator_impl_->ClearInputLinks();
  830. operator_impl_->ClearOutputLinks();
  831. OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_);
  832. }
  // Generates Operator::AttrRegister(name, ArgType): stores `attr_value` on
  // the op desc through AttrUtils::Set<AttrUtilsFun>. A missing implementation
  // object or op desc is logged as an error; a failed store is logged as a
  // warning.
  #define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun)                                                                \
    void Operator::AttrRegister(const string &name, ArgType attr_value) {                                       \
      const bool has_op_desc = (operator_impl_ != nullptr) && (operator_impl_->GetOpDescImpl() != nullptr);     \
      if (!has_op_desc) {                                                                                       \
        GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());                               \
        return;                                                                                                 \
      }                                                                                                         \
      const bool stored = AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value);      \
      if (!stored) {                                                                                            \
        GELOGW("reg attr name %s failed.", name.c_str());                                                       \
      }                                                                                                         \
    }
  843. OP_ATTR_SET_IMP(int64_t, Int)
  844. OP_ATTR_SET_IMP(int32_t, Int)
  845. OP_ATTR_SET_IMP(uint32_t, Int)
  846. OP_ATTR_GET_IMP(int64_t &, Int)
  847. OP_ATTR_GET_IMP(int32_t &, Int)
  848. OP_ATTR_GET_IMP(uint32_t &, Int)
  849. OP_ATTR_SET_IMP(const vector<int64_t> &, ListInt)
  850. OP_ATTR_SET_IMP(const vector<int32_t> &, ListInt)
  851. OP_ATTR_SET_IMP(const vector<uint32_t> &, ListInt)
  852. OP_ATTR_SET_IMP(std::initializer_list<int64_t> &&, ListInt)
  853. OP_ATTR_GET_IMP(vector<int64_t> &, ListInt)
  854. OP_ATTR_GET_IMP(vector<int32_t> &, ListInt)
  855. OP_ATTR_GET_IMP(vector<uint32_t> &, ListInt)
  856. OP_ATTR_GET_IMP(vector<vector<int64_t>> &, ListListInt)
  857. OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt)
  858. OP_ATTR_SET_IMP(float, Float)
  859. OP_ATTR_GET_IMP(float &, Float)
  860. OP_ATTR_SET_IMP(const vector<float> &, ListFloat)
  861. OP_ATTR_GET_IMP(vector<float> &, ListFloat)
  862. OP_ATTR_SET_IMP(bool, Bool)
  863. OP_ATTR_GET_IMP(bool &, Bool)
  864. OP_ATTR_SET_IMP(const vector<bool> &, ListBool)
  865. OP_ATTR_GET_IMP(vector<bool> &, ListBool)
  866. OP_ATTR_SET_IMP(const string &, Str)
  867. OP_ATTR_GET_IMP(string &, Str)
  868. OP_ATTR_SET_IMP(const vector<string> &, ListStr)
  869. OP_ATTR_GET_IMP(vector<string> &, ListStr)
  870. OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  871. OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  872. OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  873. OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  874. OP_ATTR_REG_IMP(int64_t, Int)
  875. OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt)
  876. OP_ATTR_REG_IMP(float, Float)
  877. OP_ATTR_REG_IMP(const vector<float> &, ListFloat)
  878. OP_ATTR_REG_IMP(const string &, Str)
  879. OP_ATTR_REG_IMP(const vector<string> &, ListStr)
  880. OP_ATTR_REG_IMP(bool, Bool)
  881. OP_ATTR_REG_IMP(const vector<bool> &, ListBool)
  882. OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt)
  883. OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
  884. OP_ATTR_REG_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
  885. #undef OP_ATTR_SET_IMP
  886. #undef OP_ATTR_GET_IMP
  887. #undef OP_ATTR_REG_IMP
  888. Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) {
  889. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  890. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  891. return *this;
  892. }
  893. GeTensor tensor = TensorAdapter::AsGeTensor(attr_value);
  894. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  895. GELOGW("set attr name %s failed.", name.c_str());
  896. }
  897. return *this;
  898. }
  899. Operator &Operator::SetAttr(const string &name, const vector<Tensor> &attr_value) {
  900. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  901. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  902. return *this;
  903. }
  904. vector<GeTensor> val_list;
  905. for (const auto &item : attr_value) {
  906. auto tensor = TensorAdapter::AsGeTensor(item);
  907. val_list.push_back(tensor);
  908. }
  909. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  910. GELOGW("set attr name %s failed.", name.c_str());
  911. }
  912. return *this;
  913. }
  914. graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const {
  915. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  916. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  917. return GRAPH_FAILED;
  918. }
  919. ConstGeTensorPtr tensor;
  920. if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  921. GELOGW("get attr name %s failed.", name.c_str());
  922. return GRAPH_FAILED;
  923. }
  924. attr_value = TensorAdapter::GeTensor2Tensor(tensor);
  925. return GRAPH_SUCCESS;
  926. }
  927. graphStatus Operator::GetAttr(const string &name, vector<Tensor> &attr_value) const {
  928. attr_value.clear();
  929. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  930. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  931. return GRAPH_FAILED;
  932. }
  933. vector<ConstGeTensorPtr> val_list;
  934. if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  935. GELOGW("get attr name %s failed.", name.c_str());
  936. return GRAPH_FAILED;
  937. }
  938. for (auto &tensor : val_list) {
  939. attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor));
  940. }
  941. return GRAPH_SUCCESS;
  942. }
  943. Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) {
  944. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  945. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  946. return *this;
  947. }
  948. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name,
  949. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  950. GELOGW("set attr name %s failed.", name.c_str());
  951. }
  952. return *this;
  953. }
  954. graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const {
  955. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  956. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  957. return GRAPH_FAILED;
  958. }
  959. Buffer buffer;
  960. if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) {
  961. GELOGW("get attr name %s failed.", name.c_str());
  962. return GRAPH_FAILED;
  963. }
  964. attr_value.clear();
  965. if (buffer.data() == nullptr) {
  966. GELOGE(GRAPH_FAILED, "buffer data is null.");
  967. return GRAPH_FAILED;
  968. }
  969. attr_value.assign(buffer.data(), buffer.data() + buffer.size());
  970. return GRAPH_SUCCESS;
  971. }
  972. Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) {
  973. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr.");
  974. (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_));
  975. return *this;
  976. }
  977. graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const {
  978. GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr.");
  979. return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_);
  980. }
  981. Operator &Operator::SetAttr(const string &name, const std::vector<ge::DataType> &attr_value) {
  982. if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) {
  983. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  984. return *this;
  985. }
  986. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  987. GELOGW("set attr name %s failed.", name.c_str());
  988. }
  989. return *this;
  990. }
  991. graphStatus Operator::GetAttr(const string &name, std::vector<ge::DataType> &attr_value) const {
  992. attr_value.clear();
  993. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  994. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  995. return GRAPH_FAILED;
  996. }
  997. if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  998. GELOGW("get attr name %s failed.", name.c_str());
  999. return GRAPH_FAILED;
  1000. }
  1001. return GRAPH_SUCCESS;
  1002. }
  1003. Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) {
  1004. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1005. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1006. return *this;
  1007. }
  1008. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1009. GELOGW("set attr name %s failed.", name.c_str());
  1010. }
  1011. return *this;
  1012. }
  1013. graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const {
  1014. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1015. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1016. return GRAPH_FAILED;
  1017. }
  1018. if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1019. GELOGW("get attr name %s failed.", name.c_str());
  1020. return GRAPH_FAILED;
  1021. }
  1022. return GRAPH_SUCCESS;
  1023. }
  1024. void Operator::AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value) {
  1025. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1026. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1027. return;
  1028. }
  1029. if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1030. GELOGW("set attr name %s failed.", name.c_str());
  1031. }
  1032. }
  1033. void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) {
  1034. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1035. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1036. return;
  1037. }
  1038. if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) {
  1039. GELOGW("set attr name %s failed.", name.c_str());
  1040. }
  1041. }
  1042. void Operator::AttrRegister(const string &name, const Tensor &attr_value) {
  1043. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1044. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1045. return;
  1046. }
  1047. auto tensor = TensorAdapter::AsGeTensor(attr_value);
  1048. if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
  1049. GELOGW("reg attr name %s failed.", name.c_str());
  1050. }
  1051. }
  1052. void Operator::AttrRegister(const string &name, const vector<Tensor> &attr_value) {
  1053. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1054. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1055. return;
  1056. }
  1057. vector<GeTensor> val_list;
  1058. for (const auto &item : attr_value) {
  1059. val_list.push_back(TensorAdapter::AsGeTensor(item));
  1060. }
  1061. if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) {
  1062. GELOGW("reg attr name %s failed.", name.c_str());
  1063. }
  1064. }
  1065. void Operator::AttrRegister(const string &name, const OpBytes &attr_value) {
  1066. if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
  1067. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1068. return;
  1069. }
  1070. if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name,
  1071. Buffer::CopyFrom(attr_value.data(), attr_value.size()))) {
  1072. GELOGW("reg attr name %s failed.", name.c_str());
  1073. }
  1074. }
  1075. void Operator::SubgraphRegister(const std::string &name, bool dynamic) {
  1076. if (operator_impl_ == nullptr) {
  1077. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1078. return;
  1079. }
  1080. operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic);
  1081. }
  1082. void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) {
  1083. if (operator_impl_ == nullptr) {
  1084. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
  1085. return;
  1086. }
  1087. operator_impl_->SubgraphCountRegister(name, count);
  1088. }
  1089. void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) {
  1090. if (operator_impl_ == nullptr) {
  1091. GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str());
  1092. return;
  1093. }
  1094. operator_impl_->SetSubgraphBuilder(ir_name, index, builder);
  1095. }
  1096. std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); }
  1097. SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const {
  1098. if (operator_impl_ == nullptr) {
  1099. GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
  1100. return nullptr;
  1101. }
  1102. return operator_impl_->GetSubgraphBuilder(ir_name, index);
  1103. }
  1104. SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const {
  1105. return GetDynamicSubgraphBuilder(ir_name, 0);
  1106. }
  1107. Graph Operator::GetSubgraph(const string &name) const {
  1108. if (operator_impl_ == nullptr) {
  1109. GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str());
  1110. return Graph("");
  1111. }
  1112. auto op_desc = OpDescUtils::GetOpDescFromOperator(*this);
  1113. if (op_desc == nullptr) {
  1114. GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str());
  1115. return Graph("");
  1116. }
  1117. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  1118. auto iter = subgraph_names_to_index.find(name);
  1119. if (iter == subgraph_names_to_index.end()) {
  1120. GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str());
  1121. return Graph("");
  1122. }
  1123. auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second);
  1124. if (subgraph_instance_name.empty()) {
  1125. GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second);
  1126. return Graph("");
  1127. }
  1128. auto node = operator_impl_->GetNode();
  1129. if (node == nullptr) {
  1130. GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str());
  1131. return Graph("");
  1132. }
  1133. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  1134. if (root_graph == nullptr) {
  1135. GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str());
  1136. return Graph("");
  1137. }
  1138. auto subgraph = root_graph->GetSubgraph(subgraph_instance_name);
  1139. if (subgraph == nullptr) {
  1140. GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(),
  1141. iter->second, subgraph_instance_name.c_str());
  1142. return Graph("");
  1143. }
  1144. return GraphUtils::CreateGraphFromComputeGraph(subgraph);
  1145. }
  1146. Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const {
  1147. return GetSubgraph(name + std::to_string(index));
  1148. }
  1149. size_t Operator::GetSubgraphNamesCount() const {
  1150. if (operator_impl_ == nullptr) {
  1151. GE_LOGE("Failed to get subgraph names count, the operator impl is null");
  1152. return 0;
  1153. }
  1154. return operator_impl_->GetSubgraphNamesCount();
  1155. }
  1156. class GraphBuilderImpl {
  1157. public:
  1158. explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) {
  1159. if (graph_ == nullptr) {
  1160. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  1161. return;
  1162. }
  1163. }
  1164. ~GraphBuilderImpl() {}
  1165. ComputeGraphPtr BuildGraph(const std::vector<Operator> &inputs) {
  1166. std::vector<OperatorImplPtr> vec_inputs;
  1167. for (auto &it : inputs) {
  1168. auto src_op_impl = it.operator_impl_;
  1169. GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null.");
  1170. GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null.");
  1171. string type = src_op_impl->op_desc_->GetType();
  1172. auto node_op = ge::OperatorFactory::CreateOperator("node_op", type);
  1173. auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  1174. node_op.BreakConnect();
  1175. GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null.");
  1176. if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA ||
  1177. type == VARIABLE || type == INITDATA || type == GETNEXT) {
  1178. vec_inputs.push_back(it.operator_impl_);
  1179. } else {
  1180. GELOGW("Input operator should be Data, Variable operator or operator that has output but no input.");
  1181. }
  1182. }
  1183. GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr,
  1184. "User Input do not include operator such as "
  1185. "Data, Variable operator or operator that has output but no input.");
  1186. auto ret = WalkAllOperators(vec_inputs);
  1187. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed.");
  1188. ret = AddEdge();
  1189. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed.");
  1190. return graph_;
  1191. }
  1192. const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const { return all_nodes_info_; }
  1193. private:
  1194. graphStatus WalkAllOperators(const std::vector<OperatorImplPtr> &vec_ops) {
  1195. GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.")
  1196. std::queue<std::vector<OperatorImplPtr>> que;
  1197. que.push(vec_ops);
  1198. while (!que.empty()) {
  1199. auto vec_tem = que.front();
  1200. que.pop();
  1201. for (const auto &op_impl : vec_tem) {
  1202. GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.")
  1203. GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue,
  1204. "This node %s has created.", op_impl->GetName().c_str())
  1205. auto node_ptr = graph_->AddNode(op_impl->op_desc_);
  1206. GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed.");
  1207. all_nodes_info_.insert(std::make_pair(op_impl, node_ptr));
  1208. auto &out_links = op_impl->output_links_;
  1209. std::vector<OperatorImplPtr> vec_op_forward{};
  1210. for (const auto &out_link : out_links) {
  1211. for (const auto &op_forward : out_link.second) {
  1212. vec_op_forward.push_back(op_forward.GetOwner());
  1213. }
  1214. }
  1215. auto &out_control_links = op_impl->control_output_link_;
  1216. for (const auto &out_link : out_control_links) {
  1217. vec_op_forward.push_back(out_link.lock());
  1218. }
  1219. que.push(vec_op_forward);
  1220. auto &in_links = op_impl->input_link_;
  1221. std::vector<OperatorImplPtr> vec_op_back_forward{};
  1222. for (const auto &in_link : in_links) {
  1223. vec_op_back_forward.push_back(in_link.second.GetOwner());
  1224. }
  1225. auto &in_control_links = op_impl->control_input_link_;
  1226. for (const auto &in_link : in_control_links) {
  1227. vec_op_back_forward.push_back(in_link.lock());
  1228. }
  1229. que.push(vec_op_back_forward);
  1230. if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) {
  1231. return GRAPH_FAILED;
  1232. }
  1233. }
  1234. }
  1235. return MoveSubgraphToRoot(graph_);
  1236. }
  1237. graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) {
  1238. const string name = node->GetName();
  1239. for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) {
  1240. const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first);
  1241. GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str());
  1242. Graph graph = builder(); // Build subgraph from user define builder.
  1243. const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph);
  1244. GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str());
  1245. subgraph->SetParentNode(node);
  1246. subgraph->SetParentGraph(graph_);
  1247. if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
  1248. return GRAPH_FAILED;
  1249. }
  1250. if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) {
  1251. GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second);
  1252. return GRAPH_FAILED;
  1253. }
  1254. }
  1255. return GRAPH_SUCCESS;
  1256. }
  1257. graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) {
  1258. const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph);
  1259. if (root_graph == nullptr) {
  1260. GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str());
  1261. return GRAPH_FAILED;
  1262. }
  1263. if (root_graph == graph) {
  1264. auto subgraphs = graph->GetAllSubgraphs();
  1265. for (auto &subgraph : subgraphs) {
  1266. if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) {
  1267. return GRAPH_FAILED;
  1268. }
  1269. }
  1270. } else {
  1271. auto subgraphs = graph->GetAllSubgraphs();
  1272. for (auto &subgraph : subgraphs) {
  1273. if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
  1274. return GRAPH_FAILED;
  1275. }
  1276. graph->RemoveSubgraph(subgraph->GetName());
  1277. if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) {
  1278. return GRAPH_FAILED;
  1279. }
  1280. }
  1281. }
  1282. return GRAPH_SUCCESS;
  1283. }
  1284. graphStatus AddEdge() {
  1285. for (const auto &node_info : all_nodes_info_) {
  1286. auto src_op_impl_ptr = node_info.first;
  1287. auto src_node_ptr = node_info.second;
  1288. GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue);
  1289. auto out_links = src_op_impl_ptr->output_links_;
  1290. GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED,
  1291. "Src operator impl's op_desc is null.");
  1292. auto &op_desc = src_op_impl_ptr->op_desc_;
  1293. GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
  1294. for (const auto &out : out_links) {
  1295. auto src_idx = op_desc->GetOutputIndexByName(out.first);
  1296. GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed");
  1297. auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx);
  1298. GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed.");
  1299. for (const auto &dst_opio : out.second) {
  1300. auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner());
  1301. GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed.");
  1302. GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);
  1303. auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex());
  1304. GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed.");
  1305. auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor);
  1306. GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED,
  1307. "from node[%s][%d] to node[%s][%d]AddEdge failed.", src_node_ptr->GetName().c_str(),
  1308. src_anchor->GetIdx(), dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx());
  1309. }
  1310. }
  1311. auto out_control_anchor = src_node_ptr->GetOutControlAnchor();
  1312. for (const auto &control_out : src_op_impl_ptr->control_output_link_) {
  1313. auto dst_node_info = all_nodes_info_.find(control_out.lock());
  1314. if (dst_node_info == all_nodes_info_.end()) {
  1315. GELOGE(GRAPH_FAILED, "Find Dst node failed.");
  1316. return GRAPH_FAILED;
  1317. }
  1318. GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);
  1319. auto in_control_anchor = dst_node_info->second->GetInControlAnchor();
  1320. auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor);
  1321. if (ret != GRAPH_SUCCESS) {
  1322. GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(),
  1323. op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(),
  1324. dst_node_info->second->GetType().c_str());
  1325. return ret;
  1326. }
  1327. }
  1328. }
  1329. return GRAPH_SUCCESS;
  1330. }
  1331. ComputeGraphPtr graph_ = nullptr;
  1332. std::map<OperatorImplPtr, NodePtr> all_nodes_info_{};
  1333. };
  1334. inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) {
  1335. for (const auto &graph : compute_graph->GetAllSubgraphs()) {
  1336. std::set<string> node_names;
  1337. for (auto const &node : graph->GetDirectNode()) {
  1338. auto result = node_names.insert(node->GetName());
  1339. if (!result.second) {
  1340. GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str());
  1341. return true;
  1342. }
  1343. }
  1344. }
  1345. std::set<string> node_names;
  1346. for (auto const &node : compute_graph->GetDirectNode()) {
  1347. auto result = node_names.insert(node->GetName());
  1348. if (!result.second) {
  1349. GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str());
  1350. return true;
  1351. }
  1352. }
  1353. return false;
  1354. }
  1355. ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) {
  1356. auto graph_builder_impl = GraphBuilderImpl(name);
  1357. ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs);
  1358. GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr");
  1359. compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo());
  1360. if (HasSameNameNode(compute_graph)) {
  1361. GELOGW("Compute do not allow has same name nodes.");
  1362. compute_graph = nullptr;
  1363. }
  1364. return compute_graph;
  1365. }
  1366. void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos) {
  1367. for (const auto &it : all_nodes_infos) {
  1368. OperatorImplPtr op_impl = it.first;
  1369. if (op_impl == nullptr) {
  1370. GELOGW("operator impl is nullptr.");
  1371. continue;
  1372. }
  1373. op_impl->ClearOutputLinks();
  1374. op_impl->ClearInputLinks();
  1375. OperatorKeeper::GetInstance().CheckOutOperator(op_impl);
  1376. }
  1377. }
  1378. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。