You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

shape_refiner.cc 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/shape_refiner.h"
  17. #include <memory>
  18. #include <string>
  19. #include <unordered_map>
  20. #include <utility>
  21. #include <vector>
  22. #include "graph/debug/ge_attr_define.h"
  23. #include "graph/utils/graph_utils.h"
  24. #include "debug/ge_log.h"
  25. #include "debug/ge_op_types.h"
  26. #include "external/graph/operator.h"
  27. #include "external/graph/operator_factory.h"
  28. #include "framework/common/debug/ge_log.h"
  29. #include "graph/compute_graph.h"
  30. #include "utils/node_utils.h"
  31. #include "utils/op_desc_utils.h"
  32. #include "utils/tensor_utils.h"
  33. #include "utils/type_utils.h"
  34. namespace ge {
  35. namespace {
  36. graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
  37. auto op_desc = node->GetOpDesc();
  38. auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  39. if (sub_graph_names.empty()) {
  40. return GRAPH_SUCCESS;
  41. }
  42. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  43. for (const auto &name : sub_graph_names) {
  44. if (name.empty()) {
  45. GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
  46. continue;
  47. }
  48. auto sub_graph = root_graph->GetSubgraph(name);
  49. if (sub_graph == nullptr) {
  50. GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
  51. return GRAPH_FAILED;
  52. }
  53. for (const auto &node_sub : sub_graph->GetDirectNode()) {
  54. if (node_sub->GetType() != DATA) {
  55. continue;
  56. }
  57. int ref_i;
  58. auto data_opdesc = node_sub->GetOpDesc();
  59. if (data_opdesc == nullptr) {
  60. GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
  61. node->GetName().c_str());
  62. return GRAPH_FAILED;
  63. }
  64. if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
  65. GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
  66. node->GetName().c_str());
  67. return GRAPH_FAILED;
  68. }
  69. auto input_desc = op_desc->MutableInputDesc(ref_i);
  70. if (input_desc == nullptr) {
  71. GE_LOGE(
  72. "The ref index(%d) on the data %s on the sub graph %s "
  73. "parent node %s are incompatible, inputs num %u",
  74. ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
  75. return GRAPH_FAILED;
  76. }
  77. GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
  78. node->GetName().c_str());
  79. auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
  80. if (ret != GRAPH_SUCCESS) {
  81. GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
  82. node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
  83. return ret;
  84. }
  85. ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
  86. if (ret != GRAPH_SUCCESS) {
  87. GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
  88. node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
  89. return ret;
  90. }
  91. }
  92. }
  93. return GRAPH_SUCCESS;
  94. }
  95. graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
  96. auto op_desc = node->GetOpDesc();
  97. auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  98. if (sub_graph_names.empty()) {
  99. return GRAPH_SUCCESS;
  100. }
  101. auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  102. for (const auto &name : sub_graph_names) {
  103. if (name.empty()) {
  104. GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
  105. continue;
  106. }
  107. auto sub_graph = root_graph->GetSubgraph(name);
  108. if (sub_graph == nullptr) {
  109. GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
  110. return GRAPH_FAILED;
  111. }
  112. NodePtr netoutput = nullptr;
  113. auto sub_nodes = sub_graph->GetDirectNode();
  114. for (size_t i = sub_nodes.size(); i > 0; --i) {
  115. auto sub_node = sub_nodes.at(i - 1);
  116. if (sub_node->GetType() == NETOUTPUT) {
  117. netoutput = sub_node;
  118. break;
  119. }
  120. }
  121. if (netoutput == nullptr) {
  122. GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
  123. return GRAPH_FAILED;
  124. }
  125. auto netoutput_opdesc = netoutput->GetOpDesc();
  126. if (netoutput_opdesc == nullptr) {
  127. GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
  128. node->GetName().c_str());
  129. return GRAPH_FAILED;
  130. }
  131. for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
  132. auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
  133. if (edge_desc == nullptr) {
  134. GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(),
  135. node->GetName().c_str(), edge_anchor->GetIdx());
  136. return GRAPH_FAILED;
  137. }
  138. GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(),
  139. edge_desc->GetShape().GetDimNum());
  140. int ref_i;
  141. if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
  142. // if there is no ref index on the TensorDesc, it means the output data will be ignored outer.
  143. continue;
  144. }
  145. GELOGI("Parent node index of edge desc is %d", ref_i);
  146. auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i));
  147. if (output_desc == nullptr) {
  148. GE_LOGE(
  149. "The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
  150. "parent node %s are incompatible, outputs num %u",
  151. ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
  152. node->GetAllOutDataAnchorsSize());
  153. return GRAPH_FAILED;
  154. }
  155. op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc);
  156. }
  157. }
  158. return GRAPH_SUCCESS;
  159. }
  160. } // namespace
  161. void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
  162. if (node == nullptr) {
  163. GELOGE(GRAPH_FAILED, "node is null");
  164. return;
  165. }
  166. ge::OpDescPtr op_desc = node->GetOpDesc();
  167. GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );
  168. std::string str;
  169. if (op_desc->GetInputsSize() != 0) {
  170. std::string input_desc_str = "input shape: ";
  171. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  172. input_desc_str += "[";
  173. for (int64_t dim : input_desc->GetShape().GetDims()) {
  174. input_desc_str += std::to_string(dim) + " ";
  175. }
  176. input_desc_str += "]";
  177. input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" +
  178. TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " ";
  179. }
  180. str += input_desc_str;
  181. }
  182. if (op_desc->GetAllOutputsDescSize() != 0) {
  183. std::string output_desc_str = "output shape: ";
  184. for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
  185. if (output_desc == nullptr) {
  186. continue;
  187. }
  188. output_desc_str += "[";
  189. for (int64_t dim : output_desc->GetShape().GetDims()) {
  190. output_desc_str += std::to_string(dim) + " ";
  191. }
  192. output_desc_str += "]";
  193. output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" +
  194. TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " ";
  195. }
  196. str += output_desc_str;
  197. }
  198. GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str());
  199. }
// Convenience overload: runs shape/type inference with before_subgraph = true,
// i.e. parent input descs are pushed into subgraph Data nodes before inference.
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) {
  return InferShapeAndType(node, op, true);
}
// Runs the op's registered infer function on `node`, falling back to an infer
// function fetched from the OperatorFactory when the OpDesc has none.
// When before_subgraph is true, parent input descs are first propagated into
// subgraph Data nodes; when false, subgraph NetOutput descs are propagated
// back to the parent's outputs after a successful inference.
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  auto op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);
  const auto &op_type = op_desc->GetType();
  graphStatus ret;
  if (before_subgraph) {
    ret = UpdateSubGraphDataNodes(node);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  // Get infer func and execute
  ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    // Op ir no infer func, try to get infer func from operator factory
    auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
      return ret;
    }
    GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
    auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
    // The factory operator is only needed for its OpDesc; detach it so it does
    // not keep graph references alive.
    node_op.BreakConnect();
    if (temp_op_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "temp op desc is null");
      return GRAPH_FAILED;
    }
    if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
      GELOGW("InferShapeAndType UpdateInputName failed");
      // NOTE(review): if the first output already has a non-empty shape, the
      // retry below is skipped and success is returned; an empty-shaped first
      // output breaks out and continues to the retry. Confirm this ordering is
      // intended — it inspects only the first iterated output desc.
      for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
        if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
          break;
        }
        return GRAPH_SUCCESS;
      }
    }
    if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
      GELOGW("InferShapeAndType UpdateOutputName failed");
    }
    // Adopt the factory-provided infer function and retry once.
    op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }
  if (!before_subgraph) {
    return UpdateParentNodeOutTensor(node);
  }
  return GRAPH_SUCCESS;
}
  255. InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
  256. const NodePtr &node) {
  257. if (node == nullptr) {
  258. GELOGE(GRAPH_FAILED, "node is null");
  259. return nullptr;
  260. }
  261. InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
  262. if (inference_context == nullptr) {
  263. GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext");
  264. return nullptr;
  265. }
  266. auto all_in_data_anchors = node->GetAllInDataAnchors();
  267. std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
  268. std::vector<std::string> marks;
  269. bool has_input_shapes_and_types = false;
  270. for (const auto &in_anchor : all_in_data_anchors) {
  271. const auto &out_anchor = in_anchor->GetPeerOutAnchor();
  272. if (out_anchor == nullptr) {
  273. continue;
  274. }
  275. auto input_node = out_anchor->GetOwnerNode();
  276. if (input_node == nullptr) {
  277. continue;
  278. }
  279. auto iter = context_map.find(input_node);
  280. if (iter != context_map.end()) {
  281. const auto &src_context = iter->second;
  282. GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr);
  283. GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
  284. input_node->GetName().c_str());
  285. for (auto mark : src_context->GetMarks()) {
  286. marks.push_back(mark);
  287. }
  288. auto output_idx = out_anchor->GetIdx();
  289. auto input_idx = in_anchor->GetIdx();
  290. auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
  291. if (output_idx < static_cast<int>(output_shape_and_type.size())) {
  292. GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
  293. node->GetName().c_str(), input_idx);
  294. input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
  295. has_input_shapes_and_types = true;
  296. } else {
  297. GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
  298. output_shape_and_type.size());
  299. }
  300. }
  301. }
  302. if (has_input_shapes_and_types) {
  303. inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
  304. }
  305. inference_context->SetMarks(marks);
  306. return inference_context;
  307. }
namespace {
// File-local cache of inference contexts produced by InferShapeAndType, keyed
// by node; consumed by CreateInferenceContext to seed downstream contexts.
// NOTE(review): mutable global state — not synchronized and never cleared
// here; confirm single-threaded use and lifetime expectations at call sites.
std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}
// Convenience overload: runs node-level inference with before_subgraph = true.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
  return InferShapeAndType(node, true);
}
// Top-level per-node inference driver: verifies the node, builds an
// InferenceContext seeded from upstream contexts (file-local context_map),
// runs the operator-level InferShapeAndType, pushes the results to peer
// inputs, and caches the post-inference context for downstream nodes.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node,
                                                                                           bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  if (node->Verify() != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  auto inference_context = CreateInferenceContext(context_map, node);
  if (inference_context == nullptr) {
    GELOGE(GRAPH_FAILED, "inference context is null");
    return GRAPH_FAILED;
  }
  GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
  PrintInOutTensorShape(node, "before_infershape");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);
  op.SetInferenceContext(inference_context);
  graphStatus status = InferShapeAndType(node, op, before_subgraph);
  // GRAPH_PARAM_INVALID (no infer function found) is tolerated: the existing
  // descs are still propagated to peer inputs.
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node);
  } else {
    GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  // Cache the (possibly op-modified) context only when it carries information
  // that downstream nodes could consume.
  auto ctx_after_infer = op.GetInferenceContext();
  if (ctx_after_infer != nullptr) {
    GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
    if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
      GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      (void)context_map.emplace(node, ctx_after_infer);
    }
  }
  PrintInOutTensorShape(node, "after_infershape");
  return GRAPH_SUCCESS;
}
  348. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示