You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ref_relation.cc 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/ref_relation.h"
  17. #include <unordered_set>
  18. #include <unordered_map>
  19. #include "utils/mem_utils.h"
  20. #include "debug/ge_log.h"
  21. #include "debug/ge_op_types.h"
  22. #include "debug/ge_util.h"
  23. #include "debug/ge_attr_define.h"
  24. #include "graph/ge_error_codes.h"
  25. #include "graph/utils/graph_utils.h"
  26. #include "framework/common/debug/ge_log.h"
  27. using namespace std;
  28. using namespace ge;
  29. namespace ge {
  namespace {
  // Name of the integer attribute read from subgraph Data op descs and from subgraph NetOutput input
  // tensor descs; its value is used below as the index of the parent function node's input/output.
  // Both the pointer and the pointee are const so the key cannot be re-pointed at runtime.
  const char *const kRefIndex = "_parent_node_index";
  // Function-op type names compared against Node::GetType().
  const std::string kWhile = "While";
  const std::string kIf = "If";
  const std::string kCase = "Case";
  // Lower bound on the number of ref-index buckets allocated per function node.
  const uint16_t kMaxElementNum = 100;
  // Function-op types handled by this file. NOTE(review): not referenced in the visible code; kept const
  // so it cannot be mutated as shared global state.
  const std::unordered_set<std::string> function_op = {kWhile, kIf, kCase};
  }  // namespace
  38. /* Impl */
  39. class RefRelations::Impl {
  40. public:
  41. graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) {
  42. unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get()));
  43. std::string lookup_key =
  44. key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number);
  45. auto iter = look_up_table_.find(lookup_key);
  46. if (iter != look_up_table_.end()) {
  47. for (auto &c : iter->second) {
  48. result.insert(c);
  49. }
  50. return GRAPH_SUCCESS;
  51. }
  52. GELOGW("can not find any relations! key value of dest relation is %s", lookup_key.c_str());
  53. return GRAPH_SUCCESS;
  54. };
  55. graphStatus BuildRefRelations(ge::ComputeGraph &root_graph);
  56. graphStatus Clear() {
  57. GELOGD("Start clear boundary reflections between main graph and sub graph!");
  58. look_up_table_.clear();
  59. values_.clear();
  60. return GRAPH_SUCCESS;
  61. };
  62. private:
  63. graphStatus BuildLookUpTables();
  64. graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
  65. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  66. vector<vector<RefCell>> &node_refs);
  67. graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
  68. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  69. vector<vector<RefCell>> &node_refs);
  70. graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node,
  71. const vector<vector<NodePtr>> &classed_data_nodes,
  72. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
  73. vector<vector<RefCell>> &node_refs);
  74. void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes,
  75. vector<NodePtr> &netoutput_nodes, const std::vector<std::string> &sub_graph_names,
  76. const std::string &node_type);
  77. graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph);
  78. graphStatus ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, vector<vector<NodePtr>> &classed_data_nodes);
  79. graphStatus ProcessSubgraphNetoutput(const vector<NodePtr> &netoutput_nodes,
  80. vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes);
  81. std::unordered_map<string, vector<RefCell>> look_up_table_;
  82. std::vector<vector<vector<RefCell>>> values_;
  83. };
  84. // Node Level
  85. graphStatus RefRelations::Impl::BuildRefRelationsForBranch(
  86. const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
  87. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) {
  88. GELOGD("Enter BuildRefRelationsForBranch!");
  89. size_t ref_i = 0;
  90. for (const auto &ref_i_data_nodes : classed_data_nodes) {
  91. vector<RefCell> in_ref_i_all_refs;
  92. RefCell cell_root;
  93. cell_root.node_name = root_node->GetName();
  94. cell_root.node = root_node;
  95. cell_root.in_out = NODE_IN;
  96. cell_root.in_out_idx = ref_i;
  97. in_ref_i_all_refs.emplace_back(cell_root);
  98. for (const auto &data : ref_i_data_nodes) {
  99. RefCell cell_in;
  100. RefCell cell_out;
  101. cell_in.node_name = data->GetName();
  102. cell_in.node = data;
  103. cell_in.in_out = NODE_IN;
  104. cell_in.in_out_idx = 0;
  105. cell_out.node_name = data->GetName();
  106. cell_out.node = data;
  107. cell_out.in_out = NODE_OUT;
  108. cell_out.in_out_idx = 0;
  109. in_ref_i_all_refs.emplace_back(cell_in);
  110. in_ref_i_all_refs.emplace_back(cell_out);
  111. }
  112. node_refs.emplace_back(in_ref_i_all_refs);
  113. ref_i++;
  114. }
  115. size_t ref_o = 0;
  116. for (const auto &ref_o_net_nodes : classed_netoutput_nodes) {
  117. vector<RefCell> out_ref_i_all_refs;
  118. RefCell cell_root;
  119. cell_root.node_name = root_node->GetName();
  120. cell_root.node = root_node;
  121. cell_root.in_out = NODE_OUT;
  122. cell_root.in_out_idx = ref_o;
  123. out_ref_i_all_refs.emplace_back(cell_root);
  124. for (const auto &ele : ref_o_net_nodes) {
  125. RefCell cell_netoutput_in;
  126. cell_netoutput_in.node_name = (ele.first)->GetName();
  127. cell_netoutput_in.node = ele.first;
  128. cell_netoutput_in.in_out = NODE_IN;
  129. cell_netoutput_in.in_out_idx = ele.second;
  130. out_ref_i_all_refs.emplace_back(cell_netoutput_in);
  131. }
  132. node_refs.emplace_back(out_ref_i_all_refs);
  133. ref_o++;
  134. }
  135. return GRAPH_SUCCESS;
  136. }
  137. graphStatus RefRelations::Impl::BuildLookUpTables() {
  138. GELOGD("start to build look up table!");
  139. for (size_t i = 0; i < values_.size(); i++) {
  140. vector<vector<RefCell>> &val = values_[i];
  141. for (const auto &ele : val) {
  142. for (const auto &ref_cell : ele) {
  143. string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) +
  144. std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get())));
  145. look_up_table_[key] = ele;
  146. }
  147. }
  148. }
  149. return GRAPH_SUCCESS;
  150. }
  151. graphStatus RefRelations::Impl::BuildRefRelationsForWhile(
  152. const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
  153. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) {
  154. GELOGD("Enter BuildRefRelations for while op!");
  155. // data_nodes has been sorted
  156. // for while, input num must be same as output num
  157. auto input_num = root_node->GetAllInDataAnchorsSize();
  158. NodePtr netoutput = nullptr;
  159. size_t ref_i = 0;
  160. while (ref_i < input_num) {
  161. auto &ref_i_data_nodes = classed_data_nodes[ref_i];
  162. auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i];
  163. vector<RefCell> ref_i_all_refs;
  164. RefCell cell_root_i;
  165. RefCell cell_root_o;
  166. cell_root_i.node_name = root_node->GetName();
  167. cell_root_i.node = root_node;
  168. cell_root_i.in_out = NODE_IN;
  169. cell_root_i.in_out_idx = ref_i;
  170. ref_i_all_refs.emplace_back(cell_root_i);
  171. cell_root_o.node_name = root_node->GetName();
  172. cell_root_o.node = root_node;
  173. cell_root_o.in_out = NODE_OUT;
  174. cell_root_o.in_out_idx = ref_i;
  175. ref_i_all_refs.emplace_back(cell_root_o);
  176. for (const auto &data : ref_i_data_nodes) {
  177. RefCell cell_in;
  178. RefCell cell_out;
  179. cell_in.node_name = data->GetName();
  180. cell_in.node = data;
  181. cell_in.in_out = NODE_IN;
  182. cell_in.in_out_idx = 0;
  183. cell_out.node_name = data->GetName();
  184. cell_out.node = data;
  185. cell_out.in_out = NODE_OUT;
  186. cell_out.in_out_idx = 0;
  187. ref_i_all_refs.emplace_back(cell_in);
  188. ref_i_all_refs.emplace_back(cell_out);
  189. }
  190. for (const auto &ele : ref_i_net_nodes) {
  191. RefCell cell_netoutput_in;
  192. RefCell cell_netoutput_out;
  193. cell_netoutput_in.node_name = (ele.first)->GetName();
  194. cell_netoutput_in.node = ele.first;
  195. cell_netoutput_in.in_out = NODE_IN;
  196. cell_netoutput_in.in_out_idx = ele.second;
  197. ref_i_all_refs.emplace_back(cell_netoutput_in);
  198. netoutput = ele.first;
  199. }
  200. node_refs.emplace_back(ref_i_all_refs);
  201. ref_i++;
  202. }
  203. /* There exist scene like the follows, it means data0 data1 netoutput 0'th
  204. * and 1'th tensor should be the same addr.
  205. * Data0 Data1
  206. * \/
  207. * /\
  208. * netoutput
  209. */
  210. if (netoutput == nullptr) {
  211. return GRAPH_SUCCESS;
  212. }
  213. for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) {
  214. auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
  215. if (peer_out_data_anchor == nullptr) {
  216. continue;
  217. }
  218. auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
  219. if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
  220. GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str());
  221. continue;
  222. }
  223. if (peer_out_data_node->GetType() != DATA) {
  224. continue;
  225. }
  226. auto in_data_anchor_idx = in_anchor->GetIdx();
  227. auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx));
  228. int ref_d = 0;
  229. int ref_n = 0;
  230. (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d);
  231. (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n);
  232. node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end());
  233. node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end());
  234. }
  235. return GRAPH_SUCCESS;
  236. }
  237. // build ref relations according to diff func op type
  238. graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType(
  239. const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
  240. const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) {
  241. // data_nodes has been sorted
  242. auto node_type = root_node->GetType();
  243. auto status = GRAPH_SUCCESS;
  244. if (node_type != kWhile) {
  245. status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  246. } else {
  247. status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  248. }
  249. return status;
  250. }
  251. void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes,
  252. vector<NodePtr> &netoutput_nodes,
  253. const std::vector<std::string> &sub_graph_names,
  254. const std::string &node_type) {
  255. int sub_graph_idx = 0;
  256. for (const auto &name : sub_graph_names) {
  257. auto sub_graph = root_graph.GetSubgraph(name);
  258. if (sub_graph == nullptr) {
  259. GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str());
  260. continue;
  261. }
  262. for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
  263. auto sub_graph_node_type = sub_graph_node->GetType();
  264. if (sub_graph_node_type == DATA) {
  265. data_nodes.emplace_back(sub_graph_node);
  266. } else if (sub_graph_node_type == NETOUTPUT) {
  267. // if while, the first subgraph must be cond subgraph.
  268. // There is no meaning for refs ,so continue
  269. if (node_type == kWhile && sub_graph_idx == 0) {
  270. continue;
  271. }
  272. netoutput_nodes.emplace_back(sub_graph_node);
  273. }
  274. continue;
  275. }
  276. sub_graph_idx++;
  277. }
  278. }
  279. graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) {
  280. auto parent_graph_ptr = graph.GetParentGraph();
  281. if (parent_graph_ptr == nullptr) {
  282. root_graph = graph;
  283. return GRAPH_SUCCESS;
  284. }
  285. auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr);
  286. if (root_graph_ptr == nullptr) {
  287. GE_LOGE("Get null root graph");
  288. return GRAPH_PARAM_INVALID;
  289. }
  290. root_graph = *root_graph_ptr;
  291. return GRAPH_SUCCESS;
  292. }
  293. graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes,
  294. vector<vector<NodePtr>> &classed_data_nodes) {
  295. GELOGD("start to process subgraph data nodes!");
  296. int max_ref_idx = 0;
  297. for (const auto &e : data_nodes) {
  298. int i;
  299. bool is_exist = true;
  300. is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i);
  301. if (!is_exist) {
  302. GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex);
  303. return GRAPH_FAILED;
  304. }
  305. max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx;
  306. }
  307. while (!data_nodes.empty()) {
  308. auto data = data_nodes.back();
  309. data_nodes.pop_back();
  310. int ref_idx = 0;
  311. (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx);
  312. if (ref_idx >= static_cast<int>(classed_data_nodes.size())) {
  313. return GRAPH_FAILED;
  314. }
  315. classed_data_nodes[ref_idx].emplace_back(data);
  316. }
  317. return GRAPH_SUCCESS;
  318. }
  319. graphStatus RefRelations::Impl::ProcessSubgraphNetoutput(
  320. const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) {
  321. GELOGD("[RefRelations]Start to process subgraph netoutput!");
  322. for (const auto &sub_netoutput_node : netoutput_nodes) {
  323. auto op_desc = sub_netoutput_node->GetOpDesc();
  324. GE_CHECK_NOTNULL(op_desc);
  325. for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) {
  326. auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx());
  327. if (in_desc == nullptr) {
  328. GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it",
  329. sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
  330. return GRAPH_FAILED;
  331. }
  332. int ref_o;
  333. if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) {
  334. if (ref_o >= static_cast<int>(classed_netoutput_nodes.size())) {
  335. return GRAPH_FAILED;
  336. }
  337. classed_netoutput_nodes[ref_o].emplace_back(
  338. std::pair<NodePtr, size_t>({sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())}));
  339. }
  340. }
  341. }
  342. return GRAPH_SUCCESS;
  343. }
  344. graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) {
  345. GELOGD("Start to build ref relations!");
  346. /* First Step: Get root graph */
  347. ge::ComputeGraph &root_graph = graph;
  348. auto status = GetRootGraph(graph, root_graph);
  349. if (status != GRAPH_SUCCESS) {
  350. return status;
  351. }
  352. for (const auto &node : graph.GetAllNodes()) {
  353. auto node_type = node->GetType();
  354. std::vector<NodePtr> ref_nodes;
  355. auto op_desc = node->GetOpDesc();
  356. auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  357. if (sub_graph_names.empty()) {
  358. continue;
  359. }
  360. vector<NodePtr> data_nodes;
  361. vector<NodePtr> netoutput_nodes;
  362. // Get data and netoutput of sub_graph
  363. GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type);
  364. size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum;
  365. vector<vector<NodePtr>> classed_data_nodes(max_elem_num); // according to ref_idx
  366. vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(max_elem_num); // according to ref_idx
  367. status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes);
  368. if (status != GRAPH_SUCCESS) {
  369. GELOGE(GRAPH_FAILED, "classfy data nodes failed!");
  370. return status;
  371. }
  372. // for netoutput
  373. // check netoutput
  374. // here main graph output number must be the same as every sub_graph netoutput node
  375. // key: netoutput node_ptr ,<ref_idx, net_in_idx>
  376. status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes);
  377. if (status != GRAPH_SUCCESS) {
  378. GELOGE(GRAPH_FAILED, "process netoutput failed!");
  379. return status;
  380. }
  381. vector<vector<RefCell>> node_refs;
  382. status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs);
  383. if (status != GRAPH_SUCCESS) {
  384. GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str());
  385. return status;
  386. }
  387. if (!node_refs.empty()) {
  388. values_.push_back(node_refs);
  389. }
  390. }
  391. /* Seconde Step: generate map */
  392. status = BuildLookUpTables();
  393. if (status != GRAPH_SUCCESS) {
  394. GELOGE(status, "Build look up tables failed!");
  395. return status;
  396. }
  397. return GRAPH_SUCCESS;
  398. }
  399. /* Ref Relations Interface */
  400. RefRelations::RefRelations() {
  401. impl_ = MakeShared<Impl>();
  402. if (impl_ == nullptr) {
  403. GELOGE(GRAPH_FAILED, "MakeShared failed!");
  404. return;
  405. }
  406. }
  407. graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) {
  408. GE_CHECK_NOTNULL(impl_);
  409. return impl_->LookUpRefRelations(key, result);
  410. }
  411. graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) {
  412. GE_CHECK_NOTNULL(impl_);
  413. return impl_->BuildRefRelations(root_graph);
  414. }
  415. graphStatus RefRelations::Clear() {
  416. GE_CHECK_NOTNULL(impl_);
  417. return impl_->Clear();
  418. }
  419. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示