You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ref_relation.cc 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/ref_relation.h"
  17. #include <unordered_set>
  18. #include <unordered_map>
  19. #include "utils/mem_utils.h"
  20. #include "debug/ge_log.h"
  21. #include "debug/ge_op_types.h"
  22. #include "debug/ge_util.h"
  23. #include "debug/ge_attr_define.h"
  24. #include "graph/ge_error_codes.h"
  25. #include "graph/utils/graph_utils.h"
  26. #include "framework/common/debug/ge_log.h"
  27. using namespace std;
  28. using namespace ge;
  29. namespace ge {
  namespace {
  // Name of the integer attribute this file reads (via AttrUtils::GetInt) from sub-graph Data nodes
  // and from NetOutput input descriptors to obtain their reference index.
  const char *kRefIndex = "_parent_node_index";

  // Type names of the function operators listed in function_op below.
  const std::string kWhile("While");
  const std::string kIf("If");
  const std::string kCase("Case");

  // Minimum number of per-ref-index buckets allocated for each node that owns sub-graphs.
  const uint16_t kMaxElementNum = 100;

  // Set of the function operator type names declared above.
  std::unordered_set<std::string> function_op = {kWhile, kIf, kCase};
  }
  42. /* Impl */
  43. class RefRelations::Impl {
  44. public:
  45. graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) {
  46. unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get()));
  47. std::string lookup_key = key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx)
  48. + std::to_string(number);
  49. auto iter = look_up_table_.find(lookup_key);
  50. if (iter != look_up_table_.end()) {
  51. for (auto &c : iter->second) {
  52. result.insert(c);
  53. }
  54. return GRAPH_SUCCESS;
  55. }
  56. GELOGW("can not find any relations! key value of dest relation is %s", lookup_key.c_str());
  57. return GRAPH_SUCCESS;
  58. };
  59. graphStatus BuildRefRelations(ge::ComputeGraph &root_graph);
  60. graphStatus Clear() {
  61. GELOGD("Start clear boundary reflections between main graph and sub graph!");
  62. look_up_table_.clear();
  63. values_.clear();
  64. return GRAPH_SUCCESS;
  65. };
  66. private:
  67. graphStatus BuildLookUpTables();
  68. graphStatus BuildRefRelationsForBranch(
  69. const NodePtr &root_node,
  70. const vector<vector<NodePtr>> &classed_data_nodes,
  71. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  72. vector<vector<RefCell>> &node_refs);
  73. graphStatus BuildRefRelationsForWhile(
  74. const NodePtr &root_node,
  75. const vector<vector<NodePtr>> &classed_data_nodes,
  76. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  77. vector<vector<RefCell>> &node_refs);
  78. graphStatus BuildRelationsWithFuncNodeType(
  79. const NodePtr &root_node,
  80. const vector<vector<NodePtr>> &classed_data_nodes,
  81. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  82. vector<vector<RefCell>> &node_refs);
  83. void GetDataAndNetoutputOfSubGraph(
  84. const ge::ComputeGraph &root_graph,
  85. vector<NodePtr> &data_nodes,
  86. vector<NodePtr> &netoutput_nodes,
  87. const std::vector<std::string> &sub_graph_names,
  88. const std::string &node_type);
  89. graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph);
  90. graphStatus ProcessSubgraphDataNodes(
  91. vector<NodePtr> &data_nodes,
  92. vector<vector<NodePtr>> &classed_data_nodes);
  93. graphStatus ProcessSubgraphNetoutput(
  94. const vector<NodePtr> &netoutput_nodes,
  95. vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes);
  96. std::unordered_map<string, vector<RefCell>> look_up_table_;
  97. std::vector<vector<vector<RefCell>>> values_;
  98. };
  99. // Node Level
  100. graphStatus RefRelations::Impl::BuildRefRelationsForBranch(
  101. const NodePtr &root_node,
  102. const vector<vector<NodePtr>> &classed_data_nodes,
  103. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  104. vector<vector<RefCell>> &node_refs) {
  105. GELOGD("Enter BuildRefRelationsForBranch!");
  106. size_t ref_i = 0;
  107. for (const auto &ref_i_data_nodes : classed_data_nodes) {
  108. vector<RefCell> in_ref_i_all_refs;
  109. RefCell cell_root;
  110. cell_root.node_name = root_node->GetName();
  111. cell_root.node = root_node;
  112. cell_root.in_out = NODE_IN;
  113. cell_root.in_out_idx = ref_i;
  114. in_ref_i_all_refs.emplace_back(cell_root);
  115. for (const auto &data : ref_i_data_nodes) {
  116. RefCell cell_in;
  117. RefCell cell_out;
  118. cell_in.node_name = data->GetName();
  119. cell_in.node = data;
  120. cell_in.in_out = NODE_IN;
  121. cell_in.in_out_idx = 0;
  122. cell_out.node_name = data->GetName();
  123. cell_out.node = data;
  124. cell_out.in_out = NODE_OUT;
  125. cell_out.in_out_idx = 0;
  126. in_ref_i_all_refs.emplace_back(cell_in);
  127. in_ref_i_all_refs.emplace_back(cell_out);
  128. }
  129. node_refs.emplace_back(in_ref_i_all_refs);
  130. ref_i++;
  131. }
  132. size_t ref_o = 0;
  133. for (const auto &ref_o_net_nodes : classed_netoutput_nodes) {
  134. vector<RefCell> out_ref_i_all_refs;
  135. RefCell cell_root;
  136. cell_root.node_name = root_node->GetName();
  137. cell_root.node = root_node;
  138. cell_root.in_out = NODE_OUT;
  139. cell_root.in_out_idx = ref_o;
  140. out_ref_i_all_refs.emplace_back(cell_root);
  141. for (const auto &ele : ref_o_net_nodes) {
  142. RefCell cell_netoutput_in;
  143. cell_netoutput_in.node_name = (ele.first)->GetName();
  144. cell_netoutput_in.node = ele.first;
  145. cell_netoutput_in.in_out = NODE_IN;
  146. cell_netoutput_in.in_out_idx = ele.second;
  147. out_ref_i_all_refs.emplace_back(cell_netoutput_in);
  148. }
  149. node_refs.emplace_back(out_ref_i_all_refs);
  150. ref_o++;
  151. }
  152. return GRAPH_SUCCESS;
  153. }
  154. graphStatus RefRelations::Impl::BuildLookUpTables() {
  155. GELOGD("start to build look up table!");
  156. for (size_t i = 0; i < values_.size(); i++) {
  157. vector<vector<RefCell>> &val = values_[i];
  158. for (const auto &ele : val) {
  159. for (const auto &ref_cell : ele) {
  160. string key = ref_cell.node_name + std::to_string(ref_cell.in_out) +
  161. std::to_string(ref_cell.in_out_idx) +
  162. std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get())));
  163. look_up_table_[key] = ele;
  164. }
  165. }
  166. }
  167. return GRAPH_SUCCESS;
  168. }
  169. graphStatus RefRelations::Impl::BuildRefRelationsForWhile(
  170. const NodePtr &root_node,
  171. const vector<vector<NodePtr>> &classed_data_nodes,
  172. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  173. vector<vector<RefCell>> &node_refs) {
  174. GELOGD("Enter BuildRefRelations for while op!");
  175. // data_nodes has been sorted
  176. // for while, input num must be same as output num
  177. auto input_num = root_node->GetAllInDataAnchorsSize();
  178. NodePtr netoutput = nullptr;
  179. size_t ref_i = 0;
  180. while (ref_i < input_num) {
  181. auto &ref_i_data_nodes = classed_data_nodes[ref_i];
  182. auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i];
  183. vector<RefCell> ref_i_all_refs;
  184. RefCell cell_root_i;
  185. RefCell cell_root_o;
  186. cell_root_i.node_name = root_node->GetName();
  187. cell_root_i.node = root_node;
  188. cell_root_i.in_out = NODE_IN;
  189. cell_root_i.in_out_idx = ref_i;
  190. ref_i_all_refs.emplace_back(cell_root_i);
  191. cell_root_o.node_name = root_node->GetName();
  192. cell_root_o.node = root_node;
  193. cell_root_o.in_out = NODE_OUT;
  194. cell_root_o.in_out_idx = ref_i;
  195. ref_i_all_refs.emplace_back(cell_root_o);
  196. for (const auto &data : ref_i_data_nodes) {
  197. RefCell cell_in;
  198. RefCell cell_out;
  199. cell_in.node_name = data->GetName();
  200. cell_in.node = data;
  201. cell_in.in_out = NODE_IN;
  202. cell_in.in_out_idx = 0;
  203. cell_out.node_name = data->GetName();
  204. cell_out.node = data;
  205. cell_out.in_out = NODE_OUT;
  206. cell_out.in_out_idx = 0;
  207. ref_i_all_refs.emplace_back(cell_in);
  208. ref_i_all_refs.emplace_back(cell_out);
  209. }
  210. for (const auto &ele : ref_i_net_nodes) {
  211. RefCell cell_netoutput_in;
  212. RefCell cell_netoutput_out;
  213. cell_netoutput_in.node_name = (ele.first)->GetName();
  214. cell_netoutput_in.node = ele.first;
  215. cell_netoutput_in.in_out = NODE_IN;
  216. cell_netoutput_in.in_out_idx = ele.second;
  217. ref_i_all_refs.emplace_back(cell_netoutput_in);
  218. netoutput = ele.first;
  219. }
  220. node_refs.emplace_back(ref_i_all_refs);
  221. ref_i++;
  222. }
  223. /* There exist scene like the follows, it means data0 data1 netoutput 0'th
  224. * and 1'th tensor should be the same addr.
  225. * Data0 Data1
  226. * \/
  227. * /\
  228. * netoutput
  229. */
  230. if (netoutput == nullptr) {
  231. return GRAPH_SUCCESS;
  232. }
  233. for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) {
  234. auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
  235. if (peer_out_data_anchor == nullptr) {
  236. continue;
  237. }
  238. auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
  239. if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
  240. GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str());
  241. continue;
  242. }
  243. if (peer_out_data_node->GetType() != DATA) {
  244. continue;
  245. }
  246. auto in_data_anchor_idx = in_anchor->GetIdx();
  247. auto net_in_desc =
  248. netoutput->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx));
  249. int ref_d = 0;
  250. int ref_n = 0;
  251. (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d);
  252. (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n);
  253. node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end());
  254. node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end());
  255. }
  256. return GRAPH_SUCCESS;
  257. }
  258. // build ref relations according to diff func op type
  259. graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType(
  260. const NodePtr &root_node,
  261. const vector<vector<NodePtr>> &classed_data_nodes,
  262. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  263. vector<vector<RefCell>> &node_refs) {
  264. // data_nodes has been sorted
  265. auto node_type = root_node->GetType();
  266. auto status = GRAPH_SUCCESS;
  267. if (node_type != kWhile) {
  268. status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  269. } else {
  270. status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  271. }
  272. return status;
  273. }
  274. void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(
  275. const ge::ComputeGraph &root_graph,
  276. vector<NodePtr> &data_nodes,
  277. vector<NodePtr> &netoutput_nodes,
  278. const std::vector<std::string> &sub_graph_names,
  279. const std::string &node_type) {
  280. int sub_graph_idx = 0;
  281. for (const auto &name : sub_graph_names) {
  282. auto sub_graph = root_graph.GetSubgraph(name);
  283. if (sub_graph == nullptr) {
  284. GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str());
  285. continue;
  286. }
  287. for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
  288. auto sub_graph_node_type = sub_graph_node->GetType();
  289. if (sub_graph_node_type == DATA) {
  290. data_nodes.emplace_back(sub_graph_node);
  291. } else if (sub_graph_node_type == NETOUTPUT) {
  292. // if while, the first subgraph must be cond subgraph.
  293. // There is no meaning for refs ,so continue
  294. if (node_type == kWhile && sub_graph_idx == 0) {
  295. continue;
  296. }
  297. netoutput_nodes.emplace_back(sub_graph_node);
  298. }
  299. continue;
  300. }
  301. sub_graph_idx++;
  302. }
  303. }
  304. graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) {
  305. auto parent_graph_ptr = graph.GetParentGraph();
  306. if (parent_graph_ptr == nullptr) {
  307. root_graph = graph;
  308. return GRAPH_SUCCESS;
  309. }
  310. auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr);
  311. if (root_graph_ptr == nullptr) {
  312. GE_LOGE("Get null root graph");
  313. return GRAPH_PARAM_INVALID;
  314. }
  315. root_graph = *root_graph_ptr;
  316. return GRAPH_SUCCESS;
  317. }
  318. graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(
  319. vector<NodePtr> &data_nodes,
  320. vector<vector<NodePtr>> &classed_data_nodes) {
  321. GELOGD("start to process subgraph data nodes!");
  322. int max_ref_idx = 0;
  323. for (const auto &e : data_nodes) {
  324. int i;
  325. bool is_exist = true;
  326. is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i);
  327. if (!is_exist) {
  328. GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s",
  329. e->GetName().c_str(), kRefIndex);
  330. return GRAPH_FAILED;
  331. }
  332. max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx;
  333. }
  334. while (!data_nodes.empty()) {
  335. auto data = data_nodes.back();
  336. data_nodes.pop_back();
  337. int ref_idx = 0;
  338. (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx);
  339. if (ref_idx >= static_cast<int>(classed_data_nodes.size())) {
  340. return GRAPH_FAILED;
  341. }
  342. classed_data_nodes[ref_idx].emplace_back(data);
  343. }
  344. return GRAPH_SUCCESS;
  345. }
  346. graphStatus RefRelations::Impl::ProcessSubgraphNetoutput(
  347. const vector<NodePtr> &netoutput_nodes,
  348. vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) {
  349. GELOGD("[RefRelations]Start to process subgraph netoutput!");
  350. for (const auto &sub_netoutput_node : netoutput_nodes) {
  351. auto op_desc = sub_netoutput_node->GetOpDesc();
  352. GE_CHECK_NOTNULL(op_desc);
  353. for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) {
  354. auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx());
  355. if (in_desc == nullptr) {
  356. GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it",
  357. sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
  358. return GRAPH_FAILED;
  359. }
  360. int ref_o;
  361. if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) {
  362. if (ref_o >= static_cast<int>(classed_netoutput_nodes.size())) {
  363. return GRAPH_FAILED;
  364. }
  365. classed_netoutput_nodes[ref_o].emplace_back(std::pair<NodePtr, size_t>(
  366. {sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())}
  367. ));
  368. }
  369. }
  370. }
  371. return GRAPH_SUCCESS;
  372. }
  373. graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) {
  374. GELOGD("Start to build ref relations!");
  375. /* First Step: Get root graph */
  376. ge::ComputeGraph &root_graph = graph;
  377. auto status = GetRootGraph(graph, root_graph);
  378. if (status != GRAPH_SUCCESS) {
  379. return status;
  380. }
  381. for (const auto &node : graph.GetAllNodes()) {
  382. auto node_type = node->GetType();
  383. std::vector<NodePtr> ref_nodes;
  384. auto op_desc = node->GetOpDesc();
  385. auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  386. if (sub_graph_names.empty()) {
  387. continue;
  388. }
  389. vector<NodePtr> data_nodes;
  390. vector<NodePtr> netoutput_nodes;
  391. // Get data and netoutput of sub_graph
  392. GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type);
  393. size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum;
  394. vector<vector<NodePtr>> classed_data_nodes(max_elem_num); // according to ref_idx
  395. vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(max_elem_num); // according to ref_idx
  396. status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes);
  397. if (status != GRAPH_SUCCESS) {
  398. GELOGE(GRAPH_FAILED, "classfy data nodes failed!");
  399. return status;
  400. }
  401. // for netoutput
  402. // check netoutput
  403. // here main graph output number must be the same as every sub_graph netoutput node
  404. // key: netoutput node_ptr ,<ref_idx, net_in_idx>
  405. status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes);
  406. if (status != GRAPH_SUCCESS) {
  407. GELOGE(GRAPH_FAILED, "process netoutput failed!");
  408. return status;
  409. }
  410. vector<vector<RefCell>> node_refs;
  411. status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  412. if (status != GRAPH_SUCCESS) {
  413. GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str());
  414. return status;
  415. }
  416. if (!node_refs.empty()) {
  417. values_.push_back(node_refs);
  418. }
  419. }
  420. /* Seconde Step: generate map */
  421. status = BuildLookUpTables();
  422. if (status != GRAPH_SUCCESS) {
  423. GELOGE(status, "Build look up tables failed!");
  424. return status;
  425. }
  426. return GRAPH_SUCCESS;
  427. }
  428. /* Ref Relations Interface */
  429. RefRelations::RefRelations() {
  430. impl_ = MakeShared<Impl>();
  431. if (impl_ == nullptr) {
  432. GELOGE(GRAPH_FAILED, "MakeShared failed!");
  433. return;
  434. }
  435. }
  436. graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) {
  437. GE_CHECK_NOTNULL(impl_);
  438. return impl_->LookUpRefRelations(key, result);
  439. }
  440. graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) {
  441. GE_CHECK_NOTNULL(impl_);
  442. return impl_->BuildRefRelations(root_graph);
  443. }
  444. graphStatus RefRelations::Clear() {
  445. GE_CHECK_NOTNULL(impl_);
  446. return impl_->Clear();
  447. }
  448. }

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。