You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_desc_utils.cc 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "utils/op_desc_utils.h"
#include <algorithm>
#include <atomic>
#include "debug/ge_attr_define.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "utils/graph_utils.h"
#include "utils/node_utils.h"
using std::vector;
  28. /*lint -e512 -e737 -e752*/
  29. namespace ge {
  // Attribute key under which quantization parameters are stored on an OpDesc.
  const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
  // Number of weight tensors a Const op is expected to carry. Declared as size_t
  // because it is compared against container sizes (avoids a signed/unsigned mix).
  static constexpr size_t CONST_OP_NORMAL_WEIGHT_SIZE = 1;
  32. bool OpDescUtils::ClearInputDesc(const NodePtr &node) {
  33. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  34. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  35. vector<int> index_list;
  36. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  37. if (in_anchor->GetPeerOutAnchor() == nullptr) {
  38. index_list.push_back(in_anchor->GetIdx());
  39. }
  40. }
  41. std::sort(index_list.begin(), index_list.end());
  42. // Node's in anchor index need shrink
  43. for (size_t i = 0; i < index_list.size(); ++i) {
  44. auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i];
  45. if (iter < node->GetOpDesc()->inputs_desc_.end()) {
  46. (void)node->GetOpDesc()->inputs_desc_.erase(iter);
  47. } else {
  48. GELOGW("inputs_desc_ iterator out of range.");
  49. }
  50. }
  51. return true;
  52. }
  53. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc,
  54. const uint32_t index) {
  55. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  56. GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index);
  57. auto iter = op_desc->inputs_desc_.begin() + index;
  58. if (iter < op_desc->inputs_desc_.end()) {
  59. (void)op_desc->inputs_desc_.erase(iter);
  60. } else {
  61. GELOGW("inputs_desc_ iterator out of range.");
  62. }
  63. return true;
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) {
  66. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr");
  67. return op_desc->HasAttr(OP_DESC_QUANT_PARAMS);
  68. }
  69. bool OpDescUtils::ClearOutputDesc(const NodePtr &node) {
  70. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  71. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  72. vector<int> index_list;
  73. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  74. if (out_anchor->GetPeerInDataAnchors().empty()) {
  75. index_list.push_back(out_anchor->GetIdx());
  76. }
  77. }
  78. std::sort(index_list.begin(), index_list.end());
  79. // Node's out anchor index need shrink
  80. for (size_t i = 0; i < index_list.size(); ++i) {
  81. auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i];
  82. if (iter < node->GetOpDesc()->outputs_desc_.end()) {
  83. (void)node->GetOpDesc()->outputs_desc_.erase(iter);
  84. } else {
  85. GELOGW("outputs_desc_ iterator out of range.");
  86. }
  87. }
  88. return true;
  89. }
  90. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc,
  91. uint32_t index) {
  92. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  93. GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index);
  94. auto iter = op_desc->outputs_desc_.begin() + index;
  95. if (iter < op_desc->outputs_desc_.end()) {
  96. (void)op_desc->outputs_desc_.erase(iter);
  97. } else {
  98. GELOGW("outputs_desc_ iterator out of range.");
  99. }
  100. return true;
  101. }
  102. bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); }
  103. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  104. OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) {
  105. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  106. GeAttrValue attr_value;
  107. GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  108. "GetQuantizeFactorParams failed");
  109. return attr_value.GetValue<QuantizeFactorParams>(quant);
  110. }
  111. graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) {
  112. GeAttrValue attr_value;
  113. GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  114. "GetQuantizeFactorParams failed");
  115. return attr_value.GetValue<QuantizeFactorParams>(quant);
  116. }
  117. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  118. OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
  119. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  120. return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
  121. }
  122. graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
  123. return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
  124. }
  125. GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
  126. GeTensorPtr weight = nullptr;
  127. if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) {
  128. GELOGW("MutableTensor error");
  129. }
  130. return weight;
  131. }
  132. GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) {
  133. if (op_desc == nullptr) {
  134. GELOGE(GRAPH_FAILED, "op_desc is null");
  135. return nullptr;
  136. }
  137. return MutableWeights(*op_desc);
  138. }
  139. graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) {
  140. if (weight == nullptr) {
  141. GELOGE(GRAPH_FAILED, "weight is null");
  142. return GRAPH_FAILED;
  143. }
  144. return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED;
  145. }
  146. graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) {
  147. GE_CHECK_NOTNULL(op_desc);
  148. GE_CHECK_NOTNULL(weight);
  149. return SetWeights(*op_desc, weight);
  150. }
  151. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) {
  152. auto weights = MutableWeights(node);
  153. vector<ConstGeTensorPtr> ret(weights.size());
  154. std::copy(weights.begin(), weights.end(), ret.begin());
  155. return ret;
  156. }
  157. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(
  158. const ge::ConstNodePtr &node) {
  159. if (node == nullptr) {
  160. return vector<ge::ConstGeTensorPtr>();
  161. }
  162. return GetWeights(*node);
  163. }
  164. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode(
  165. const ge::Node &node) {
  166. vector<ge::NodePtr> ret;
  167. auto in_anchors = node.GetAllInDataAnchors();
  168. for (const auto &in_anchor : in_anchors) {
  169. auto out_anchor = in_anchor->GetPeerOutAnchor();
  170. if (out_anchor == nullptr) {
  171. // normally out_anchor could be null, this is ok
  172. GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str());
  173. continue;
  174. }
  175. auto in_node = out_anchor->GetOwnerNode();
  176. while (true) {
  177. if (in_node == nullptr) {
  178. break;
  179. }
  180. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  181. ret.push_back(in_node);
  182. break;
  183. } else if (in_node->GetType() == DATA) {
  184. if (NodeUtils::IsWhileVaryingInput(in_node)) {
  185. break;
  186. }
  187. in_node = NodeUtils::GetParentInput(in_node);
  188. } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
  189. bool is_constant = false;
  190. (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
  191. if (!is_constant) {
  192. break;
  193. }
  194. // Enter node has and only has one input
  195. if (in_node->GetInDataNodes().size() != 1) {
  196. GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
  197. in_node->GetInDataNodes().size());
  198. break;
  199. }
  200. in_node = in_node->GetInDataNodes().at(0);
  201. } else {
  202. break;
  203. }
  204. }
  205. }
  206. return ret;
  207. }
  208. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
  209. const vector<ge::NodePtr> &input_nodes) {
  210. vector<ConstGeTensorPtr> ret;
  211. for (const auto &input_node : input_nodes) {
  212. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  213. if (temp_weight == nullptr) {
  214. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  215. return vector<ConstGeTensorPtr>();
  216. }
  217. ret.push_back(temp_weight);
  218. }
  219. return ret;
  220. }
  221. size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
  222. if (NodeUtils::IsAnchorStatusSet(node)) {
  223. size_t input_num = 0;
  224. for (const auto &anchor : node.GetAllInDataAnchors()) {
  225. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  226. input_num++;
  227. continue;
  228. }
  229. }
  230. return input_num; // lint !e712
  231. } else {
  232. GE_IF_BOOL_EXEC(
  233. node.GetInDataNodes().size() < GetConstInputs(node).size(),
  234. GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size());
  235. return 0);
  236. return node.GetInDataNodes().size() - GetConstInputs(node).size();
  237. }
  238. }
  239. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) {
  240. if (node == nullptr) {
  241. GELOGE(GRAPH_FAILED, "Node is nullptr");
  242. return 0;
  243. }
  244. return GetNonConstInputsSize(*node);
  245. }
  246. GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) {
  247. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!");
  248. size_t i = 0;
  249. if (NodeUtils::IsAnchorStatusSet(node)) {
  250. for (const auto &anchor : node.GetAllInDataAnchors()) {
  251. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  252. if (index_non_const == i) {
  253. return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx()));
  254. }
  255. ++i;
  256. }
  257. }
  258. } else {
  259. for (const auto &anchor : node.GetAllInDataAnchors()) {
  260. auto peer_anchor = anchor->GetPeerOutAnchor();
  261. if (peer_anchor == nullptr) {
  262. continue;
  263. }
  264. auto owner_node = peer_anchor->GetOwnerNode();
  265. if (owner_node == nullptr) {
  266. continue;
  267. }
  268. if (owner_node->GetType() == CONSTANT) {
  269. continue;
  270. }
  271. if (index_non_const == i) {
  272. return node.GetOpDesc()->GetInputDesc(anchor->GetIdx());
  273. }
  274. ++i;
  275. }
  276. }
  277. return GeTensorDesc();
  278. }
  279. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc
  280. OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) {
  281. CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc());
  282. return GetNonConstInputTensorDesc(*node, index_non_const);
  283. }
  284. bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) {
  285. bool ret = false;
  286. size_t i = 0;
  287. if (NodeUtils::IsAnchorStatusSet(node)) {
  288. for (const auto &anchor : node.GetAllInDataAnchors()) {
  289. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  290. if (index_non_const == i) {
  291. index = static_cast<size_t>(anchor->GetIdx());
  292. ret = true;
  293. }
  294. ++i;
  295. }
  296. }
  297. } else {
  298. for (const auto &anchor : node.GetAllInDataAnchors()) {
  299. auto peer_anchor = anchor->GetPeerOutAnchor();
  300. if (peer_anchor == nullptr) {
  301. continue;
  302. }
  303. auto owner_node = peer_anchor->GetOwnerNode();
  304. if (owner_node == nullptr) {
  305. continue;
  306. }
  307. if (owner_node->GetType() == CONSTANT) {
  308. continue;
  309. }
  310. if (index_non_const == i) {
  311. index = static_cast<size_t>(anchor->GetIdx());
  312. ret = true;
  313. }
  314. ++i;
  315. }
  316. }
  317. return ret;
  318. }
  319. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node,
  320. size_t index_non_const,
  321. size_t &index) {
  322. CHECK_FALSE_EXEC(node != nullptr, return false);
  323. return GetNonConstInputIndex(*node, index_non_const, index);
  324. }
  325. bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
  326. bool ret = false;
  327. if (index < node.GetAllInDataAnchors().size()) {
  328. if (NodeUtils::IsAnchorStatusSet(node)) {
  329. ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712
  330. } else {
  331. for (const auto &anchor : node.GetAllInDataAnchors()) {
  332. if (anchor->GetIdx() != static_cast<int>(index)) {
  333. continue;
  334. }
  335. auto peer_anchor = anchor->GetPeerOutAnchor();
  336. if (peer_anchor == nullptr) {
  337. break;
  338. }
  339. auto owner_node = peer_anchor->GetOwnerNode();
  340. if (owner_node == nullptr) {
  341. break;
  342. }
  343. ret = (owner_node->GetType() != CONSTANT);
  344. }
  345. }
  346. }
  347. return ret;
  348. }
  349. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node,
  350. size_t index) {
  351. CHECK_FALSE_EXEC(node != nullptr, return false);
  352. return IsNonConstInput(*node, index);
  353. }
  354. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(
  355. const ge::ConstNodePtr &node) {
  356. if (node == nullptr) {
  357. return vector<ge::NodePtr>();
  358. }
  359. return GetConstInputs(*node);
  360. }
  361. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc(
  362. const ge::ConstNodePtr &node) {
  363. if (node == nullptr || node->GetOpDesc() == nullptr) {
  364. return vector<ge::GeTensorDesc>();
  365. }
  366. vector<ge::GeTensorDesc> ret;
  367. if (NodeUtils::IsAnchorStatusSet(*node)) {
  368. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  369. if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
  370. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  371. }
  372. }
  373. } else {
  374. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  375. auto out_anchor = in_anchor->GetPeerOutAnchor();
  376. if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  377. continue;
  378. }
  379. if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
  380. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  381. }
  382. }
  383. }
  384. return ret;
  385. }
  386. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) {
  387. vector<ge::NodePtr> ret;
  388. auto in_anchors = node.GetAllInDataAnchors();
  389. for (const auto &in_anchor : in_anchors) {
  390. auto out_anchor = in_anchor->GetPeerOutAnchor();
  391. if (out_anchor == nullptr) continue;
  392. auto in_node = out_anchor->GetOwnerNode();
  393. if (in_node->GetType() == CONSTANT) {
  394. ret.push_back(in_node);
  395. } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) {
  396. // const --> switch --> matmul
  397. auto switch_input = GetConstInputs(*in_node);
  398. if (switch_input.size() > 0) {
  399. ret.insert(ret.end(), switch_input.begin(), switch_input.end());
  400. }
  401. } else if (in_node->GetType() == DATA) {
  402. auto parent = NodeUtils::GetParentInput(in_node);
  403. if ((parent != nullptr) && (parent->GetType() == CONSTANT)) {
  404. ret.push_back(parent);
  405. }
  406. }
  407. }
  408. return ret;
  409. }
  410. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
  411. vector<GeTensorPtr> ret;
  412. auto op_desc = node.GetOpDesc();
  413. GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
  414. // Place holder operator, try to get the weight from parent node
  415. // when parent node is const operator
  416. if (node.GetType() == PLACEHOLDER) {
  417. std::string parent_op;
  418. (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
  419. // This if judgment is necessary because the current subgraph optimization is multithreaded
  420. // and the parent node of the PLD operation should be a stable type, such as const
  421. if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
  422. NodePtr parent_node = nullptr;
  423. parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
  424. if (parent_node != nullptr) {
  425. op_desc = parent_node->GetOpDesc();
  426. GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
  427. }
  428. }
  429. }
  430. // Const operator, take the weight directly
  431. if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
  432. auto weight = MutableWeights(op_desc);
  433. if (weight == nullptr) {
  434. GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
  435. return ret;
  436. }
  437. ret.push_back(weight);
  438. return ret;
  439. }
  440. // Other operators, get weights from connected constop
  441. auto input_nodes = GetConstInputs(node);
  442. for (const auto &input_node : input_nodes) {
  443. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  444. if (temp_weight == nullptr) {
  445. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  446. return vector<GeTensorPtr>();
  447. }
  448. ret.push_back(temp_weight);
  449. }
  450. return ret;
  451. }
  452. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) {
  453. if (node == nullptr) {
  454. GELOGE(GRAPH_FAILED, "Node is nullptr");
  455. return vector<ge::GeTensorPtr>();
  456. }
  457. return MutableWeights(*node);
  458. }
  459. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  460. OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) {
  461. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!");
  462. if (node.GetOpDesc()->GetType() == CONSTANT) {
  463. if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  464. return SetWeights(node.GetOpDesc(), weights[0]);
  465. }
  466. GELOGI("const op weight size %zu should be 1", weights.size());
  467. return GRAPH_PARAM_INVALID;
  468. }
  469. auto input_nodes = GetConstInputs(node);
  470. if (weights.size() < input_nodes.size()) {
  471. GELOGE(GRAPH_FAILED, "weights count can't be less than const input count");
  472. return GRAPH_PARAM_INVALID;
  473. }
  474. ge::GeAttrValue::NAMED_ATTRS named_attrs;
  475. (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights);
  476. vector<ge::GeTensorPtr> copy_weights;
  477. (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights);
  478. for (size_t i = 0; i < input_nodes.size(); ++i) {
  479. if (input_nodes[i]->GetOpDesc() != nullptr) {
  480. SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]);
  481. }
  482. }
  483. // If set more weights than constop, need to add constop
  484. for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) {
  485. // Use org weight before SetWeights Overwrite
  486. auto const_opdesc = CreateConstOp(copy_weights[i]);
  487. GE_CHECK_NOTNULL(const_opdesc);
  488. auto owner_graph = node.GetOwnerComputeGraph();
  489. if (owner_graph == nullptr) {
  490. GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str());
  491. return GRAPH_PARAM_INVALID;
  492. }
  493. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  494. GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed!");
  495. std::vector<ge::NodePtr> original_nodes;
  496. ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
  497. }
  498. return GRAPH_SUCCESS;
  499. }
  500. OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) {
  501. GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!");
  502. shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>();
  503. if (const_opdesc == nullptr) {
  504. GELOGE(GRAPH_FAILED, "failed to make_shared ");
  505. return nullptr;
  506. }
  507. CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr);
  508. const_opdesc->SetType(CONSTANT);
  509. static int const_count = 0;
  510. const_opdesc->SetName("dynamic_const_" + std::to_string(const_count));
  511. GELOGI("add const op: %s", const_opdesc->GetName().c_str());
  512. ++const_count;
  513. (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc());
  514. GELOGI("after add const op: %s", const_opdesc->GetName().c_str());
  515. return const_opdesc;
  516. }
  517. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  518. OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) {
  519. GE_CHECK_NOTNULL(in_anchor);
  520. GE_CHECK_NOTNULL(tensor_ptr);
  521. auto const_opdesc = CreateConstOp(tensor_ptr);
  522. GE_CHECK_NOTNULL(const_opdesc);
  523. auto in_node = in_anchor->GetOwnerNode();
  524. GE_CHECK_NOTNULL(in_node);
  525. auto owner_graph = in_node->GetOwnerComputeGraph();
  526. if (owner_graph == nullptr) {
  527. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str());
  528. return GRAPH_PARAM_INVALID;
  529. }
  530. auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc);
  531. GE_CHECK_NOTNULL(const_node);
  532. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  533. GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed.");
  534. return GRAPH_PARAM_INVALID;
  535. }
  536. return GRAPH_SUCCESS;
  537. }
  538. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  539. OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) {
  540. GE_CHECK_NOTNULL(node);
  541. return SetWeights(*node, weights);
  542. }
  543. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) {
  544. GE_CHECK_NOTNULL(node);
  545. auto const_ops = GetConstInputs(node);
  546. auto graph = node->GetOwnerComputeGraph();
  547. if (graph == nullptr) {
  548. GELOGE(GRAPH_FAILED, "Graph is nullptr");
  549. return GRAPH_PARAM_INVALID;
  550. }
  551. for (const auto &const_op : const_ops) {
  552. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed",
  553. const_op->GetName().c_str(), const_op->GetType().c_str());
  554. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op),
  555. "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(),
  556. const_op->GetType().c_str());
  557. }
  558. return GRAPH_SUCCESS;
  559. }
  560. ///
  561. /// @brief Add input
  562. /// @param [in] name
  563. /// @return OpDescBuilder
  564. ///
  565. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) {
  566. inputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  567. return *this;
  568. }
  569. ///
  570. /// @brief Add input
  571. /// @param [in] name
  572. /// @param [in] tensor
  573. /// @return OpDescBuilder
  574. ///
  575. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name,
  576. const GeTensorDesc &tensor) {
  577. inputs_.emplace_back(std::make_pair(name, tensor));
  578. return *this;
  579. }
  580. ///
  581. /// @brief Add dynamic input
  582. /// @param [in] name
  583. /// @param [in] num
  584. /// @return OpDescBuilder
  585. ///
  586. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name,
  587. uint32_t num) {
  588. for (uint32_t i = 0; i < num; i++) {
  589. inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  590. }
  591. return *this;
  592. }
  593. ///
  594. /// @brief Add dynamic input
  595. /// @param [in] name
  596. /// @param [in] num
  597. /// @param [in] tensor
  598. /// @return OpDescBuilder
  599. ///
  600. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(
  601. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  602. for (uint32_t i = 0; i < num; i++) {
  603. inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  604. }
  605. return *this;
  606. }
  607. ///
  608. /// @brief Add output
  609. /// @param [in] name
  610. /// @return OpDescBuilder
  611. ///
  612. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) {
  613. outputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  614. return *this;
  615. }
  616. ///
  617. /// @brief Add output
  618. /// @param [in] name
  619. /// @param [in] tensor
  620. /// @return OpDescBuilder
  621. ///
  622. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name,
  623. const GeTensorDesc &tensor) {
  624. outputs_.emplace_back(std::make_pair(name, tensor));
  625. return *this;
  626. }
  627. ///
  628. /// @brief Add dynamic output
  629. /// @param [in] name
  630. /// @param [in] num
  631. /// @return OpDescBuilder
  632. ///
  633. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name,
  634. uint32_t num) {
  635. for (uint32_t i = 0; i < num; i++) {
  636. outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  637. }
  638. return *this;
  639. }
  640. ///
  641. /// @brief Add dynamic output
  642. /// @param [in] name
  643. /// @param [in] num
  644. /// @param [in] tensor
  645. /// @return OpDescBuilder
  646. ///
  647. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(
  648. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  649. for (uint32_t i = 0; i < num; i++) {
  650. outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  651. }
  652. return *this;
  653. }
  654. ///
  655. /// @brief Build op_desc
  656. /// @return OpDescPtr
  657. ///
  658. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
  659. OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
  660. if (op_desc == nullptr) {
  661. GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
  662. return nullptr;
  663. }
  664. for (auto &input : inputs_) {
  665. if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) {
  666. GELOGE(GRAPH_FAILED, "Add input_desc failed.");
  667. return nullptr;
  668. }
  669. }
  670. for (auto &output : outputs_) {
  671. if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) {
  672. GELOGE(GRAPH_FAILED, "Add output_desc failed.");
  673. return nullptr;
  674. }
  675. }
  676. return op_desc;
  677. }
  678. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName(
  679. const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) {
  680. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  681. auto iter = subgraph_names_to_index.find(subgraph_name);
  682. if (iter == subgraph_names_to_index.end()) {
  683. GELOGE(GRAPH_PARAM_INVALID,
  684. "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists",
  685. subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  686. subgraph_name.c_str());
  687. return GRAPH_PARAM_INVALID;
  688. }
  689. return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
  690. }
  691. } // namespace ge
  692. /*lint +e512 +e737 +e752*/

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：