You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_desc_utils.cc 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/op_desc_utils.h"
  17. #include <algorithm>
  18. #include "debug/ge_attr_define.h"
  19. #include "debug/ge_op_types.h"
  20. #include "debug/ge_util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "graph/anchor.h"
  23. #include "graph/compute_graph.h"
  24. #include "graph/ge_attr_value.h"
  25. #include "utils/graph_utils.h"
  26. #include "utils/node_utils.h"
  27. using std::vector;
  28. namespace ge {
  // Attribute name used by the quantize-factor helpers below.
  const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
  // Number of weights a Const op is expected to hold. Typed as size_t because it is only
  // compared against container sizes, which avoids a signed/unsigned comparison.
  static const size_t CONST_OP_NORMAL_WEIGHT_SIZE = 1;
  31. bool OpDescUtils::ClearInputDesc(const NodePtr &node) {
  32. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  33. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  34. vector<int> index_list;
  35. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  36. if (in_anchor->GetPeerOutAnchor() == nullptr) {
  37. index_list.push_back(in_anchor->GetIdx());
  38. }
  39. }
  40. std::sort(index_list.begin(), index_list.end());
  41. // Node's in anchor index need shrink
  42. for (size_t i = 0; i < index_list.size(); ++i) {
  43. auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i];
  44. if (iter < node->GetOpDesc()->inputs_desc_.end()) {
  45. (void)node->GetOpDesc()->inputs_desc_.erase(iter);
  46. } else {
  47. GELOGW("inputs_desc_ iterator out of range.");
  48. }
  49. }
  50. return true;
  51. }
  52. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc,
  53. const uint32_t index) {
  54. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  55. GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index);
  56. auto iter = op_desc->inputs_desc_.begin() + index;
  57. if (iter < op_desc->inputs_desc_.end()) {
  58. (void)op_desc->inputs_desc_.erase(iter);
  59. } else {
  60. GELOGW("inputs_desc_ iterator out of range.");
  61. }
  62. return true;
  63. }
  64. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) {
  65. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr");
  66. return op_desc->HasAttr(OP_DESC_QUANT_PARAMS);
  67. }
  68. bool OpDescUtils::ClearOutputDesc(const NodePtr &node) {
  69. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
  70. GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
  71. vector<int> index_list;
  72. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  73. if (out_anchor->GetPeerInDataAnchors().empty()) {
  74. index_list.push_back(out_anchor->GetIdx());
  75. }
  76. }
  77. std::sort(index_list.begin(), index_list.end());
  78. // Node's out anchor index need shrink
  79. for (size_t i = 0; i < index_list.size(); ++i) {
  80. auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i];
  81. if (iter < node->GetOpDesc()->outputs_desc_.end()) {
  82. (void)node->GetOpDesc()->outputs_desc_.erase(iter);
  83. } else {
  84. GELOGW("outputs_desc_ iterator out of range.");
  85. }
  86. }
  87. return true;
  88. }
  89. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc,
  90. uint32_t index) {
  91. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
  92. GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index);
  93. auto iter = op_desc->outputs_desc_.begin() + index;
  94. if (iter < op_desc->outputs_desc_.end()) {
  95. (void)op_desc->outputs_desc_.erase(iter);
  96. } else {
  97. GELOGW("outputs_desc_ iterator out of range.");
  98. }
  99. return true;
  100. }
  101. bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); }
  102. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  103. OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) {
  104. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  105. GeAttrValue attr_value;
  106. GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  107. "GetQuantizeFactorParams failed");
  108. return attr_value.GetValue<QuantizeFactorParams>(quant);
  109. }
  110. graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) {
  111. GeAttrValue attr_value;
  112. GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
  113. "GetQuantizeFactorParams failed");
  114. return attr_value.GetValue<QuantizeFactorParams>(quant);
  115. }
  116. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  117. OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
  118. GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
  119. return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
  120. }
  121. graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
  122. return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
  123. }
  124. GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
  125. GeTensorPtr weight = nullptr;
  126. if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) {
  127. GELOGW("MutableTensor error");
  128. }
  129. return weight;
  130. }
  131. GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) {
  132. if (op_desc == nullptr) {
  133. GELOGE(GRAPH_FAILED, "op_desc is null");
  134. return nullptr;
  135. }
  136. return MutableWeights(*op_desc);
  137. }
  138. graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) {
  139. if (weight == nullptr) {
  140. GELOGE(GRAPH_FAILED, "weight is null");
  141. return GRAPH_FAILED;
  142. }
  143. return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED;
  144. }
  145. graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) {
  146. GE_CHECK_NOTNULL(op_desc);
  147. GE_CHECK_NOTNULL(weight);
  148. return SetWeights(*op_desc, weight);
  149. }
  150. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) {
  151. auto weights = MutableWeights(node);
  152. vector<ConstGeTensorPtr> ret(weights.size());
  153. std::copy(weights.begin(), weights.end(), ret.begin());
  154. return ret;
  155. }
  156. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(
  157. const ge::ConstNodePtr &node) {
  158. if (node == nullptr) {
  159. return vector<ge::ConstGeTensorPtr>();
  160. }
  161. return GetWeights(*node);
  162. }
  163. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode(
  164. const ge::Node &node) {
  165. vector<ge::NodePtr> ret;
  166. auto in_anchors = node.GetAllInDataAnchors();
  167. for (const auto &in_anchor : in_anchors) {
  168. auto out_anchor = in_anchor->GetPeerOutAnchor();
  169. if (out_anchor == nullptr) {
  170. // normally out_anchor could be null, this is ok
  171. GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str());
  172. continue;
  173. }
  174. auto in_node = out_anchor->GetOwnerNode();
  175. while (true) {
  176. if (in_node == nullptr) {
  177. break;
  178. }
  179. if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
  180. ret.push_back(in_node);
  181. break;
  182. } else if (in_node->GetType() == DATA) {
  183. if (NodeUtils::IsWhileVaryingInput(in_node)) {
  184. break;
  185. }
  186. in_node = NodeUtils::GetParentInput(in_node);
  187. } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
  188. bool is_constant = false;
  189. (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
  190. if (!is_constant) {
  191. break;
  192. }
  193. // Enter node has and only has one input
  194. if (in_node->GetInDataNodes().size() != 1) {
  195. GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
  196. in_node->GetInDataNodes().size());
  197. break;
  198. }
  199. in_node = in_node->GetInDataNodes().at(0);
  200. } else {
  201. break;
  202. }
  203. }
  204. }
  205. return ret;
  206. }
  207. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
  208. const vector<ge::NodePtr> &input_nodes) {
  209. vector<ConstGeTensorPtr> ret;
  210. for (const auto &input_node : input_nodes) {
  211. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  212. if (temp_weight == nullptr) {
  213. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  214. return vector<ConstGeTensorPtr>();
  215. }
  216. ret.push_back(temp_weight);
  217. }
  218. return ret;
  219. }
  220. size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
  221. if (NodeUtils::IsAnchorStatusSet(node)) {
  222. size_t input_num = 0;
  223. for (const auto &anchor : node.GetAllInDataAnchors()) {
  224. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  225. input_num++;
  226. continue;
  227. }
  228. }
  229. return input_num;
  230. } else {
  231. GE_IF_BOOL_EXEC(
  232. node.GetInDataNodes().size() < GetConstInputs(node).size(),
  233. GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size());
  234. return 0);
  235. return node.GetInDataNodes().size() - GetConstInputs(node).size();
  236. }
  237. }
  238. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) {
  239. if (node == nullptr) {
  240. GELOGE(GRAPH_FAILED, "Node is nullptr");
  241. return 0;
  242. }
  243. return GetNonConstInputsSize(*node);
  244. }
  245. GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) {
  246. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!");
  247. size_t i = 0;
  248. if (NodeUtils::IsAnchorStatusSet(node)) {
  249. for (const auto &anchor : node.GetAllInDataAnchors()) {
  250. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  251. if (index_non_const == i) {
  252. return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx()));
  253. }
  254. ++i;
  255. }
  256. }
  257. } else {
  258. for (const auto &anchor : node.GetAllInDataAnchors()) {
  259. auto peer_anchor = anchor->GetPeerOutAnchor();
  260. if (peer_anchor == nullptr) {
  261. continue;
  262. }
  263. auto owner_node = peer_anchor->GetOwnerNode();
  264. if (owner_node == nullptr) {
  265. continue;
  266. }
  267. if (owner_node->GetType() == CONSTANT) {
  268. continue;
  269. }
  270. if (index_non_const == i) {
  271. return node.GetOpDesc()->GetInputDesc(anchor->GetIdx());
  272. }
  273. ++i;
  274. }
  275. }
  276. return GeTensorDesc();
  277. }
  278. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc
  279. OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) {
  280. CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc());
  281. return GetNonConstInputTensorDesc(*node, index_non_const);
  282. }
  283. bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) {
  284. bool ret = false;
  285. size_t i = 0;
  286. if (NodeUtils::IsAnchorStatusSet(node)) {
  287. for (const auto &anchor : node.GetAllInDataAnchors()) {
  288. if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
  289. if (index_non_const == i) {
  290. index = static_cast<size_t>(anchor->GetIdx());
  291. ret = true;
  292. }
  293. ++i;
  294. }
  295. }
  296. } else {
  297. for (const auto &anchor : node.GetAllInDataAnchors()) {
  298. auto peer_anchor = anchor->GetPeerOutAnchor();
  299. if (peer_anchor == nullptr) {
  300. continue;
  301. }
  302. auto owner_node = peer_anchor->GetOwnerNode();
  303. if (owner_node == nullptr) {
  304. continue;
  305. }
  306. if (owner_node->GetType() == CONSTANT) {
  307. continue;
  308. }
  309. if (index_non_const == i) {
  310. index = static_cast<size_t>(anchor->GetIdx());
  311. ret = true;
  312. }
  313. ++i;
  314. }
  315. }
  316. return ret;
  317. }
  318. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node,
  319. size_t index_non_const,
  320. size_t &index) {
  321. CHECK_FALSE_EXEC(node != nullptr, return false);
  322. return GetNonConstInputIndex(*node, index_non_const, index);
  323. }
  324. bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
  325. bool ret = false;
  326. if (index < node.GetAllInDataAnchors().size()) {
  327. if (NodeUtils::IsAnchorStatusSet(node)) {
  328. ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA);
  329. } else {
  330. for (const auto &anchor : node.GetAllInDataAnchors()) {
  331. if (anchor->GetIdx() != static_cast<int>(index)) {
  332. continue;
  333. }
  334. auto peer_anchor = anchor->GetPeerOutAnchor();
  335. if (peer_anchor == nullptr) {
  336. break;
  337. }
  338. auto owner_node = peer_anchor->GetOwnerNode();
  339. if (owner_node == nullptr) {
  340. break;
  341. }
  342. ret = (owner_node->GetType() != CONSTANT);
  343. }
  344. }
  345. }
  346. return ret;
  347. }
  348. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node,
  349. size_t index) {
  350. CHECK_FALSE_EXEC(node != nullptr, return false);
  351. return IsNonConstInput(*node, index);
  352. }
  353. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(
  354. const ge::ConstNodePtr &node) {
  355. if (node == nullptr) {
  356. return vector<ge::NodePtr>();
  357. }
  358. return GetConstInputs(*node);
  359. }
  360. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc(
  361. const ge::ConstNodePtr &node) {
  362. if (node == nullptr || node->GetOpDesc() == nullptr) {
  363. return vector<ge::GeTensorDesc>();
  364. }
  365. vector<ge::GeTensorDesc> ret;
  366. if (NodeUtils::IsAnchorStatusSet(*node)) {
  367. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  368. if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
  369. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  370. }
  371. }
  372. } else {
  373. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  374. auto out_anchor = in_anchor->GetPeerOutAnchor();
  375. if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  376. continue;
  377. }
  378. if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
  379. ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
  380. }
  381. }
  382. }
  383. return ret;
  384. }
  385. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) {
  386. vector<ge::NodePtr> ret;
  387. auto in_anchors = node.GetAllInDataAnchors();
  388. for (const auto &in_anchor : in_anchors) {
  389. auto out_anchor = in_anchor->GetPeerOutAnchor();
  390. if (out_anchor == nullptr) continue;
  391. auto in_node = out_anchor->GetOwnerNode();
  392. if (in_node->GetType() == CONSTANT) {
  393. ret.push_back(in_node);
  394. } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) {
  395. // const --> switch --> matmul
  396. auto switch_input = GetConstInputs(*in_node);
  397. if (switch_input.size() > 0) {
  398. ret.insert(ret.end(), switch_input.begin(), switch_input.end());
  399. }
  400. } else if (in_node->GetType() == DATA) {
  401. auto parent = NodeUtils::GetParentInput(in_node);
  402. if ((parent != nullptr) && (parent->GetType() == CONSTANT)) {
  403. ret.push_back(parent);
  404. }
  405. }
  406. }
  407. return ret;
  408. }
  409. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
  410. vector<GeTensorPtr> ret;
  411. auto op_desc = node.GetOpDesc();
  412. GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
  413. // Place holder operator, try to get the weight from parent node
  414. // when parent node is const operator
  415. if (node.GetType() == PLACEHOLDER) {
  416. std::string parent_op;
  417. (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
  418. // This if judgment is necessary because the current subgraph optimization is multithreaded
  419. // and the parent node of the PLD operation should be a stable type, such as const
  420. if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
  421. NodePtr parent_node = nullptr;
  422. parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
  423. if (parent_node != nullptr) {
  424. op_desc = parent_node->GetOpDesc();
  425. GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
  426. }
  427. }
  428. }
  429. // Const operator, take the weight directly
  430. if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
  431. auto weight = MutableWeights(op_desc);
  432. if (weight == nullptr) {
  433. GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
  434. return ret;
  435. }
  436. ret.push_back(weight);
  437. return ret;
  438. }
  439. if (node.GetType() == DATA) {
  440. auto parent = NodeUtils::GetParentInput(node);
  441. if ((parent != nullptr) && NodeUtils::IsConst(*parent)) {
  442. auto weight = MutableWeights(parent->GetOpDesc());
  443. if (weight == nullptr) {
  444. GELOGI("const op has no weight, op name:%s", parent->GetName().c_str());
  445. return ret;
  446. }
  447. ret.push_back(weight);
  448. }
  449. return ret;
  450. }
  451. // Other operators, get weights from connected constop
  452. auto input_nodes = GetConstInputs(node);
  453. for (const auto &input_node : input_nodes) {
  454. auto temp_weight = MutableWeights(input_node->GetOpDesc());
  455. if (temp_weight == nullptr) {
  456. GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
  457. return vector<GeTensorPtr>();
  458. }
  459. ret.push_back(temp_weight);
  460. }
  461. return ret;
  462. }
  463. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) {
  464. if (node == nullptr) {
  465. GELOGE(GRAPH_FAILED, "Node is nullptr");
  466. return vector<ge::GeTensorPtr>();
  467. }
  468. return MutableWeights(*node);
  469. }
  470. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  471. OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) {
  472. GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!");
  473. if (node.GetOpDesc()->GetType() == CONSTANT) {
  474. if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  475. return SetWeights(node.GetOpDesc(), weights[0]);
  476. }
  477. GELOGI("const op weight size %zu should be 1", weights.size());
  478. return GRAPH_PARAM_INVALID;
  479. }
  480. auto input_nodes = GetConstInputs(node);
  481. if (weights.size() < input_nodes.size()) {
  482. GELOGE(GRAPH_FAILED, "weights count can't be less than const input count");
  483. return GRAPH_PARAM_INVALID;
  484. }
  485. ge::GeAttrValue::NAMED_ATTRS named_attrs;
  486. (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights);
  487. vector<ge::GeTensorPtr> copy_weights;
  488. (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights);
  489. for (size_t i = 0; i < input_nodes.size(); ++i) {
  490. if (input_nodes[i]->GetOpDesc() != nullptr) {
  491. SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]);
  492. }
  493. }
  494. // If set more weights than constop, need to add constop
  495. for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) {
  496. // Use org weight before SetWeights Overwrite
  497. auto const_opdesc = CreateConstOp(copy_weights[i]);
  498. GE_CHECK_NOTNULL(const_opdesc);
  499. auto owner_graph = node.GetOwnerComputeGraph();
  500. if (owner_graph == nullptr) {
  501. GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str());
  502. return GRAPH_PARAM_INVALID;
  503. }
  504. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  505. GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed!");
  506. std::vector<ge::NodePtr> original_nodes;
  507. ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
  508. }
  509. return GRAPH_SUCCESS;
  510. }
  511. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  512. OpDescUtils::SetWeights(ge::Node &node, const map<int, ge::GeTensorPtr> &weights_map) {
  513. GE_CHECK_NOTNULL(node.GetOpDesc());
  514. // 1. node is const
  515. if (node.GetOpDesc()->GetType() == CONSTANT) {
  516. if (weights_map.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
  517. return SetWeights(node.GetOpDesc(), weights_map.begin()->second);
  518. }
  519. GELOGE(GRAPH_PARAM_INVALID, "const op %s weight size %zu should be 1", node.GetName().c_str(), weights_map.size());
  520. return GRAPH_PARAM_INVALID;
  521. }
  522. // 2. node is not const
  523. for (const auto &pair : weights_map) {
  524. auto in_data_anchor = node.GetInDataAnchor(pair.first);
  525. if (in_data_anchor != nullptr && in_data_anchor->GetPeerOutAnchor() != nullptr) {
  526. // a. update const input node
  527. auto out_anchor = in_data_anchor->GetPeerOutAnchor();
  528. auto peer_node = out_anchor->GetOwnerNode();
  529. if (peer_node == nullptr) {
  530. GELOGE(GRAPH_PARAM_INVALID, "op %s [%d]'s input node is null", node.GetName().c_str(), pair.first);
  531. return GRAPH_PARAM_INVALID;
  532. }
  533. if (peer_node->GetType() != CONSTANT) {
  534. GELOGE(GRAPH_PARAM_INVALID, " op %s [%d]'s input node should be const, but is %s type:%s ",
  535. node.GetName().c_str(), pair.first, peer_node->GetName().c_str(), peer_node->GetType().c_str());
  536. }
  537. SetWeights(peer_node->GetOpDesc(), pair.second);
  538. } else {
  539. // b. create new const input node
  540. auto const_opdesc = CreateConstOp(pair.second);
  541. GE_CHECK_NOTNULL(const_opdesc);
  542. auto owner_graph = node.GetOwnerComputeGraph();
  543. if (owner_graph == nullptr) {
  544. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", node.GetName().c_str());
  545. return GRAPH_PARAM_INVALID;
  546. }
  547. auto const_node = owner_graph->AddNodeFront(const_opdesc);
  548. if (node.AddLinkFrom(static_cast<uint32_t>(pair.first), const_node) != GRAPH_SUCCESS) {
  549. GELOGE(GRAPH_FAILED, "op %s add const to input index[%d] failed", node.GetName().c_str(), pair.first);
  550. return GRAPH_FAILED;
  551. }
  552. }
  553. }
  554. NodeUtils::UpdateIsInputConst(node);
  555. return GRAPH_SUCCESS;
  556. }
  557. OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) {
  558. GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!");
  559. shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>();
  560. if (const_opdesc == nullptr) {
  561. GELOGE(GRAPH_FAILED, "failed to make_shared ");
  562. return nullptr;
  563. }
  564. CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr);
  565. const_opdesc->SetType(CONSTANT);
  566. thread_local int64_t const_count = 0;
  567. const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count));
  568. GELOGI("add const op: %s", const_opdesc->GetName().c_str());
  569. ++const_count;
  570. (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc());
  571. GELOGI("after add const op: %s", const_opdesc->GetName().c_str());
  572. return const_opdesc;
  573. }
  574. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  575. OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) {
  576. GE_CHECK_NOTNULL(in_anchor);
  577. GE_CHECK_NOTNULL(tensor_ptr);
  578. auto const_opdesc = CreateConstOp(tensor_ptr);
  579. GE_CHECK_NOTNULL(const_opdesc);
  580. auto in_node = in_anchor->GetOwnerNode();
  581. GE_CHECK_NOTNULL(in_node);
  582. auto owner_graph = in_node->GetOwnerComputeGraph();
  583. if (owner_graph == nullptr) {
  584. GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str());
  585. return GRAPH_PARAM_INVALID;
  586. }
  587. auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc);
  588. GE_CHECK_NOTNULL(const_node);
  589. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  590. GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed.");
  591. return GRAPH_PARAM_INVALID;
  592. }
  593. return GRAPH_SUCCESS;
  594. }
  595. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  596. OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) {
  597. GE_CHECK_NOTNULL(node);
  598. return SetWeights(*node, weights);
  599. }
  600. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) {
  601. GE_CHECK_NOTNULL(node);
  602. auto const_ops = GetConstInputs(node);
  603. auto graph = node->GetOwnerComputeGraph();
  604. if (graph == nullptr) {
  605. GELOGE(GRAPH_FAILED, "Graph is nullptr");
  606. return GRAPH_PARAM_INVALID;
  607. }
  608. for (const auto &const_op : const_ops) {
  609. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed",
  610. const_op->GetName().c_str(), const_op->GetType().c_str());
  611. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op),
  612. "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(),
  613. const_op->GetType().c_str());
  614. }
  615. return GRAPH_SUCCESS;
  616. }
  617. ///
  618. /// @brief Add input
  619. /// @param [in] name
  620. /// @return OpDescBuilder
  621. ///
  622. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) {
  623. inputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  624. return *this;
  625. }
  626. ///
  627. /// @brief Add input
  628. /// @param [in] name
  629. /// @param [in] tensor
  630. /// @return OpDescBuilder
  631. ///
  632. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name,
  633. const GeTensorDesc &tensor) {
  634. inputs_.emplace_back(std::make_pair(name, tensor));
  635. return *this;
  636. }
  637. ///
  638. /// @brief Add dynamic input
  639. /// @param [in] name
  640. /// @param [in] num
  641. /// @return OpDescBuilder
  642. ///
  643. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name,
  644. uint32_t num) {
  645. for (uint32_t i = 0; i < num; i++) {
  646. inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  647. }
  648. return *this;
  649. }
  650. ///
  651. /// @brief Add dynamic input
  652. /// @param [in] name
  653. /// @param [in] num
  654. /// @param [in] tensor
  655. /// @return OpDescBuilder
  656. ///
  657. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(
  658. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  659. for (uint32_t i = 0; i < num; i++) {
  660. inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  661. }
  662. return *this;
  663. }
  664. ///
  665. /// @brief Add output
  666. /// @param [in] name
  667. /// @return OpDescBuilder
  668. ///
  669. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) {
  670. outputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
  671. return *this;
  672. }
  673. ///
  674. /// @brief Add output
  675. /// @param [in] name
  676. /// @param [in] tensor
  677. /// @return OpDescBuilder
  678. ///
  679. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name,
  680. const GeTensorDesc &tensor) {
  681. outputs_.emplace_back(std::make_pair(name, tensor));
  682. return *this;
  683. }
  684. ///
  685. /// @brief Add dynamic output
  686. /// @param [in] name
  687. /// @param [in] num
  688. /// @return OpDescBuilder
  689. ///
  690. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name,
  691. uint32_t num) {
  692. for (uint32_t i = 0; i < num; i++) {
  693. outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
  694. }
  695. return *this;
  696. }
  697. ///
  698. /// @brief Add dynamic output
  699. /// @param [in] name
  700. /// @param [in] num
  701. /// @param [in] tensor
  702. /// @return OpDescBuilder
  703. ///
  704. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(
  705. const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
  706. for (uint32_t i = 0; i < num; i++) {
  707. outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
  708. }
  709. return *this;
  710. }
  711. ///
  712. /// @brief Build op_desc
  713. /// @return OpDescPtr
  714. ///
  715. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
  716. OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
  717. if (op_desc == nullptr) {
  718. GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
  719. return nullptr;
  720. }
  721. for (auto &input : inputs_) {
  722. if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) {
  723. GELOGE(GRAPH_FAILED, "Add input_desc failed.");
  724. return nullptr;
  725. }
  726. }
  727. for (auto &output : outputs_) {
  728. if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) {
  729. GELOGE(GRAPH_FAILED, "Add output_desc failed.");
  730. return nullptr;
  731. }
  732. }
  733. return op_desc;
  734. }
  735. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName(
  736. const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) {
  737. const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
  738. auto iter = subgraph_names_to_index.find(subgraph_name);
  739. if (iter == subgraph_names_to_index.end()) {
  740. GELOGE(GRAPH_PARAM_INVALID,
  741. "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists",
  742. subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  743. subgraph_name.c_str());
  744. return GRAPH_PARAM_INVALID;
  745. }
  746. return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
  747. }
  748. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: