
shape_refiner.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/shape_refiner.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "framework/common/types.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "external/graph/operator.h"
#include "external/graph/operator_factory.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"
namespace ge {
namespace {
constexpr const char *kRefIndex = "parent_node_index";

// Copy the parent node's input TensorDesc (selected by each Data node's
// parent_node_index attribute) onto the input and output descs of the Data
// nodes inside every subgraph of `node`.
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }
      int ref_i;
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        GE_LOGE(
            "The ref index(%d) on the data %s on the sub graph %s "
            "parent node %s is incompatible, inputs num %u",
            ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
            node->GetAllOutDataAnchorsSize());
        return GRAPH_FAILED;
      }
      auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
    }
  }
  return GRAPH_SUCCESS;
}
// Propagate the shapes inferred on each subgraph's NetOutput inputs back to
// the parent node's outputs, matched by the parent_node_index attribute on
// the NetOutput input descs.
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    NodePtr netoutput = nullptr;
    auto sub_nodes = sub_graph->GetDirectNode();
    for (size_t i = sub_nodes.size(); i > 0; --i) {
      auto sub_node = sub_nodes.at(i - 1);
      if (sub_node->GetType() == NETOUTPUT) {
        netoutput = sub_node;
        break;
      }
    }
    if (netoutput == nullptr) {
      GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(),
                node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) {
        // If there is no ref index on the TensorDesc, the output data is ignored by the outer graph.
        continue;
      }
      auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i));
      if (output_desc == nullptr) {
        GE_LOGE(
            "The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
            "parent node %s is incompatible, outputs num %u",
            ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
            node->GetAllOutDataAnchorsSize());
        return GRAPH_FAILED;
      }
      op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc);
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );
  std::string str;
  if (op_desc->GetInputsSize() != 0) {
    std::string input_desc_str = "input shape: ";
    for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
      input_desc_str += "[";
      for (int64_t dim : input_desc->GetShape().GetDims()) {
        input_desc_str += std::to_string(dim) + " ";
      }
      input_desc_str += "]";
      input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" +
                        TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " ";
    }
    str += input_desc_str;
  }

  if (op_desc->GetAllOutputsDescSize() != 0) {
    std::string output_desc_str = "output shape: ";
    for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
      if (output_desc == nullptr) {
        continue;
      }
      output_desc_str += "[";
      for (int64_t dim : output_desc->GetShape().GetDims()) {
        output_desc_str += std::to_string(dim) + " ";
      }
      output_desc_str += "]";
      output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" +
                         TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " ";
    }
    str += output_desc_str;
  }
  GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str());
}
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) {
  return InferShapeAndType(node, op, true);
}

graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  auto op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);
  const auto &op_type = op_desc->GetType();

  graphStatus ret;
  if (before_subgraph) {
    ret = UpdateSubGraphDataNodes(node);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  // Get infer func and execute
  ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    // The op IR has no infer func; try to get one from the operator factory
    auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
      return ret;
    }

    GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
    auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
    if (temp_op_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "temp op desc is null");
      return GRAPH_FAILED;
    }
    if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
      GELOGW("InferShapeAndType UpdateInputName failed");
      // When renaming fails, keep the current descs and return unless the
      // first output desc still has an empty shape.
      for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
        if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
          break;
        }
        return GRAPH_SUCCESS;
      }
    }
    if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
      GELOGW("InferShapeAndType UpdateOutputName failed");
    }
    op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }

  if (!before_subgraph) {
    return UpdateParentNodeOutTensor(node);
  }
  return GRAPH_SUCCESS;
}
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
                                           const NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return nullptr;
  }
  InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
  if (inference_context == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext");
    return nullptr;
  }

  auto all_in_data_anchors = node->GetAllInDataAnchors();
  std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
  std::vector<std::string> marks;

  bool has_input_shapes_and_types = false;
  for (const auto &in_anchor : all_in_data_anchors) {
    const auto &out_anchor = in_anchor->GetPeerOutAnchor();
    if (out_anchor == nullptr) {
      continue;
    }
    auto input_node = out_anchor->GetOwnerNode();
    if (input_node == nullptr) {
      continue;
    }
    auto iter = context_map.find(input_node);
    if (iter != context_map.end()) {
      const auto &src_context = iter->second;
      GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr);
      GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
             input_node->GetName().c_str());
      for (auto mark : src_context->GetMarks()) {
        marks.push_back(mark);
      }
      auto output_idx = out_anchor->GetIdx();
      auto input_idx = in_anchor->GetIdx();
      auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
      if (output_idx < static_cast<int>(output_shape_and_type.size())) {
        GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
               node->GetName().c_str(), input_idx);
        input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
        has_input_shapes_and_types = true;
      } else {
        GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
               output_shape_and_type.size());
      }
    }
  }

  if (has_input_shapes_and_types) {
    inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
  }
  inference_context->SetMarks(marks);

  return inference_context;
}
namespace {
std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
  return InferShapeAndType(node, true);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node,
                                                                                           bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  if (node->Verify() != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }

  auto inference_context = CreateInferenceContext(context_map, node);
  if (inference_context == nullptr) {
    GELOGE(GRAPH_FAILED, "inference context is null");
    return GRAPH_FAILED;
  }
  GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());

  PrintInOutTensorShape(node, "before_infershape");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  op.SetInferenceContext(inference_context);

  graphStatus status = InferShapeAndType(node, op, before_subgraph);
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node);
  } else {
    GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }

  auto ctx_after_infer = op.GetInferenceContext();
  if (ctx_after_infer != nullptr) {
    GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
    if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
      GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
             ctx_after_infer->GetMarks().size());
      (void)context_map.emplace(node, ctx_after_infer);
    }
  }
  PrintInOutTensorShape(node, "after_infershape");
  return GRAPH_SUCCESS;
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.
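To connect this description with shape_refiner.cc above, here is a minimal sketch of how an infer-shape step might drive ShapeRefiner over every node of a compute graph. It only uses the ShapeRefiner::InferShapeAndType overload and the traversal calls that appear in the file; the helper name RunInferShapePass and the assumption that the graph is already topologically sorted are illustrative, not part of GE's API.

// Minimal sketch (hypothetical helper, not part of GE): run shape inference
// over a compute graph by calling ShapeRefiner on each node. Assumes the
// graph nodes are already in topological order, so a node's inputs are
// refined before the node itself is visited.
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "graph/shape_refiner.h"

namespace ge {
graphStatus RunInferShapePass(const ComputeGraphPtr &graph) {
  if (graph == nullptr) {
    GELOGE(GRAPH_FAILED, "graph is null");
    return GRAPH_FAILED;
  }
  for (const auto &node : graph->GetDirectNode()) {
    // InferShapeAndType also pushes the refreshed output descs to the peer
    // input descs via NodeUtils::UpdatePeerNodeInputDesc.
    if (ShapeRefiner::InferShapeAndType(node) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Infer shape failed for node %s", node->GetName().c_str());
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace ge

In GE itself, a loop of this kind belongs to the graph preparation and optimization passes rather than to shape_refiner.cc, which only provides the per-node entry points shown above.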