You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_item.cc 13 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "node_item.h"
  17. #include <sstream>
  18. #include "common/debug/log.h"
  19. #include "graph/common/omg_util.h"
  20. #include "graph/compute_graph.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/node_utils.h"
  23. #include "hybrid/executor/worker/shape_inference_engine.h"
  24. #include "hybrid/node_executor/node_executor.h"
  25. namespace ge {
  26. namespace hybrid {
  27. namespace {
  28. const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
  29. const char *const kNodeTypeRetVal = "_RetVal";
  30. std::set<std::string> kControlOpTypes{
  31. IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
  32. };
  33. Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
  34. uint32_t parent_index = 0;
  35. if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  36. GELOGE(FAILED,
  37. "[%s] Failed to get attr [%s]",
  38. op_desc.GetName().c_str(),
  39. ATTR_NAME_PARENT_NODE_INDEX.c_str());
  40. return FAILED;
  41. }
  42. for (auto &node_and_anchor : node.GetOutDataNodesAndAnchors()) {
  43. auto dst_op_desc = node_and_anchor.first->GetOpDesc();
  44. GE_CHECK_NOTNULL(dst_op_desc);
  45. auto in_idx = node_and_anchor.second->GetIdx();
  46. auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx);
  47. fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc);
  48. GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx);
  49. }
  50. return SUCCESS;
  51. }
  52. Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgraph) {
  53. uint32_t parent_index = 0;
  54. if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  55. GELOGE(FAILED,
  56. "[%s] Failed to get attr [%s]",
  57. op_desc->GetName().c_str(),
  58. ATTR_NAME_PARENT_NODE_INDEX.c_str());
  59. return FAILED;
  60. }
  61. fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc);
  62. return SUCCESS;
  63. }
  64. Status ParseFusedSubgraph(NodeItem &node_item) {
  65. if (!node_item.op_desc->HasAttr(kAttrNameOriginalFusionGraph)) {
  66. return SUCCESS;
  67. }
  68. GELOGI("[%s] Start to parse fused subgraph.", node_item.node_name.c_str());
  69. auto fused_subgraph = std::unique_ptr<FusedSubgraph>(new(std::nothrow)FusedSubgraph());
  70. GE_CHECK_NOTNULL(fused_subgraph);
  71. ComputeGraphPtr fused_graph;
  72. (void) AttrUtils::GetGraph(*node_item.op_desc, kAttrNameOriginalFusionGraph, fused_graph);
  73. GE_CHECK_NOTNULL(fused_graph);
  74. fused_graph->SetGraphUnknownFlag(true);
  75. fused_subgraph->graph = fused_graph;
  76. GE_CHK_GRAPH_STATUS_RET(fused_graph->TopologicalSorting());
  77. for (auto &node : fused_graph->GetAllNodes()) {
  78. GE_CHECK_NOTNULL(node);
  79. auto op_desc = node->GetOpDesc();
  80. GE_CHECK_NOTNULL(op_desc);
  81. std::string node_type;
  82. GE_CHK_STATUS_RET(GetOriginalType(node, node_type));
  83. if (node_type == DATA) {
  84. GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph));
  85. } else if (node_type == kNodeTypeRetVal) {
  86. GE_CHK_GRAPH_STATUS_RET(ParseOutputMapping(op_desc, *fused_subgraph));
  87. } else {
  88. fused_subgraph->nodes.emplace_back(node);
  89. }
  90. }
  91. node_item.fused_subgraph = std::move(fused_subgraph);
  92. GELOGI("[%s] Done parsing fused subgraph successfully.", node_item.NodeName().c_str());
  93. return SUCCESS;
  94. }
  95. } // namespace
  96. bool IsControlOp(const std::string &op_type) {
  97. return kControlOpTypes.count(op_type) > 0;
  98. }
  99. NodeItem::NodeItem(NodePtr node) : node(std::move(node)) {
  100. this->op_desc = this->node->GetOpDesc().get();
  101. this->node_name = this->node->GetName();
  102. this->node_type = this->node->GetType();
  103. }
  104. Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_item) {
  105. GE_CHECK_NOTNULL(node);
  106. GE_CHECK_NOTNULL(node->GetOpDesc());
  107. std::unique_ptr<NodeItem> instance(new(std::nothrow)NodeItem(node));
  108. GE_CHECK_NOTNULL(instance);
  109. GE_CHK_STATUS_RET(instance->Init(), "Failed to init NodeItem [%s] .", node->GetName().c_str());
  110. node_item = std::move(instance);
  111. return SUCCESS;
  112. }
  113. void NodeItem::ResolveOptionalInputs() {
  114. if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) {
  115. has_optional_inputs = true;
  116. for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) {
  117. const auto &input_desc = op_desc->MutableInputDesc(i);
  118. if (input_desc == nullptr) {
  119. GELOGD("[%s] Input[%zu] is optional and invalid", NodeName().c_str(), i);
  120. } else {
  121. input_desc_indices_.emplace_back(static_cast<uint32_t>(i));
  122. }
  123. }
  124. }
  125. }
  126. Status NodeItem::InitInputsAndOutputs() {
  127. GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
  128. GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
  129. num_inputs = static_cast<int>(op_desc->GetInputsSize());
  130. num_outputs = static_cast<int>(op_desc->GetOutputsSize());
  131. if (AttrUtils::GetInt(op_desc, ::ge::ATTR_STAGE_LEVEL, group)) {
  132. GELOGD("[%s] Got stage level from op_desc = %d", op_desc->GetName().c_str(), group);
  133. } else {
  134. if (node->GetOwnerComputeGraph() != nullptr) {
  135. if (AttrUtils::GetInt(node->GetOwnerComputeGraph(), ::ge::ATTR_STAGE_LEVEL, group)) {
  136. GELOGD("[%s] Got stage level from parent graph = %d", op_desc->GetName().c_str(), group);
  137. } else {
  138. auto parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  139. if ((parent_node != nullptr) && (AttrUtils::GetInt(parent_node->GetOpDesc(), ::ge::ATTR_STAGE_LEVEL, group))) {
  140. GELOGD("[%s] Got stage level from parent node = %d", op_desc->GetName().c_str(), group);
  141. } else {
  142. GELOGD("[%s] Node do not set stage level", op_desc->GetName().c_str());
  143. }
  144. }
  145. }
  146. }
  147. ResolveOptionalInputs();
  148. return SUCCESS;
  149. }
  150. Status NodeItem::ResolveDynamicState() {
  151. (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic);
  152. GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic);
  153. if (!is_dynamic) {
  154. GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_dynamic),
  155. "[%s] Failed to get shape status.",
  156. node->GetName().c_str());
  157. }
  158. return SUCCESS;
  159. }
  160. Status NodeItem::ResolveStaticInputsAndOutputs() {
  161. for (int i = 0; i < num_inputs; ++i) {
  162. // Data has unconnected input but set by framework
  163. if (node_type != DATA) {
  164. int origin_index = i;
  165. if (has_optional_inputs) {
  166. origin_index = input_desc_indices_[i];
  167. }
  168. auto in_data_anchor = node->GetInDataAnchor(origin_index);
  169. GE_CHECK_NOTNULL(in_data_anchor);
  170. // If no node was connected to the current input anchor
  171. // increase num_static_input_shapes in case dead wait in ShapeInferenceState::AwaitShapesReady
  172. if (in_data_anchor->GetPeerOutAnchor() == nullptr ||
  173. in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr) {
  174. num_static_input_shapes++;
  175. is_input_shape_static_.push_back(true);
  176. GELOGW("[%s] Peer node of input[%d] is empty", NodeName().c_str(), i);
  177. continue;
  178. }
  179. }
  180. const auto &input_desc = MutableInputDesc(i);
  181. GE_CHECK_NOTNULL(input_desc);
  182. if (input_desc->MutableShape().IsUnknownShape()) {
  183. is_input_shape_static_.push_back(false);
  184. } else {
  185. num_static_input_shapes++;
  186. is_input_shape_static_.push_back(true);
  187. GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
  188. NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
  189. }
  190. }
  191. for (int i = 0; i < num_outputs; ++i) {
  192. const auto &output_desc = op_desc->MutableOutputDesc(i);
  193. GE_CHECK_NOTNULL(output_desc);
  194. if (output_desc->MutableShape().IsUnknownShape()) {
  195. is_output_shape_static = false;
  196. break;
  197. }
  198. }
  199. if (is_output_shape_static) {
  200. GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this));
  201. }
  202. return SUCCESS;
  203. }
  204. void NodeItem::ResolveUnknownShapeType() {
  205. if (IsControlOp() || node_type == PARTITIONEDCALL) {
  206. shape_inference_type = DEPEND_COMPUTE;
  207. } else {
  208. int32_t unknown_shape_type_val = 0;
  209. (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
  210. shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
  211. }
  212. }
  213. Status NodeItem::Init() {
  214. GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs());
  215. GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState());
  216. ResolveUnknownShapeType();
  217. if (is_dynamic) {
  218. GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs());
  219. GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str());
  220. }
  221. return SUCCESS;
  222. }
  223. bool NodeItem::IsControlOp() const {
  224. return ge::hybrid::IsControlOp(op_desc->GetType());
  225. }
  226. bool NodeItem::IsHcclOp() const {
  227. return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
  228. }
  229. std::string NodeItem::DebugString() const {
  230. std::stringstream ss;
  231. ss << "Node: ";
  232. ss << "id = " << node_id;
  233. ss << ", name = [" << node->GetName();
  234. ss << "], type = " << node->GetType();
  235. ss << ", is_dynamic = " << (is_dynamic ? "True" : "False");
  236. ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False");
  237. ss << ", unknown_shape_op_type = " << shape_inference_type;
  238. ss << ", stage = " << group;
  239. ss << ", input_start = " << input_start;
  240. ss << ", num_inputs = " << num_inputs;
  241. ss << ", output_start = " << output_start;
  242. ss << ", num_outputs = " << num_outputs;
  243. ss << ", dependent_nodes = [";
  244. for (const auto &dep_node : dependents_for_shape_inference) {
  245. ss << dep_node->GetName() << ", ";
  246. }
  247. ss << "]";
  248. int index = 0;
  249. for (auto &items : outputs) {
  250. ss << ", output[" << index++ << "]: ";
  251. for (auto &item : items) {
  252. ss << "(" << item.second->NodeName() << ":" << item.first << "), ";
  253. }
  254. }
  255. return ss.str();
  256. }
  257. void NodeItem::SetToDynamic() {
  258. num_static_input_shapes = 0;
  259. is_dynamic = true;
  260. for (size_t i = 0; i < is_input_shape_static_.size(); ++i) {
  261. is_input_shape_static_[i] = false;
  262. }
  263. if (kernel_task != nullptr && !kernel_task->IsSupportDynamicShape()) {
  264. GELOGD("[%s] Dynamic shape is not supported, clear node task.", node_name.c_str());
  265. kernel_task = nullptr;
  266. }
  267. }
  268. GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
  269. if (!has_optional_inputs) {
  270. return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
  271. }
  272. if (index < 0 || index >= num_inputs) {
  273. GELOGE(PARAM_INVALID,
  274. "[%s] Invalid input index, num inputs = %d, index = %d",
  275. node_name.c_str(),
  276. num_inputs,
  277. index);
  278. return nullptr;
  279. }
  280. return op_desc->MutableInputDesc(input_desc_indices_[index]);
  281. }
  282. GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
  283. std::lock_guard<std::mutex> lk(mu_);
  284. return DoGetInputDesc(index);
  285. }
  286. Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
  287. std::lock_guard<std::mutex> lk(mu_);
  288. auto input_desc = DoGetInputDesc(index);
  289. GE_CHECK_NOTNULL(input_desc);
  290. tensor_desc = *input_desc;
  291. return SUCCESS;
  292. }
  293. Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
  294. std::lock_guard<std::mutex> lk(mu_);
  295. auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
  296. GE_CHECK_NOTNULL(output_desc);
  297. tensor_desc = *output_desc;
  298. return SUCCESS;
  299. }
  300. GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
  301. std::lock_guard<std::mutex> lk(mu_);
  302. return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
  303. }
  304. Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
  305. std::lock_guard<std::mutex> lk(mu_);
  306. auto input_desc = DoGetInputDesc(index);
  307. GE_CHECK_NOTNULL(input_desc);
  308. *input_desc = tensor_desc;
  309. return SUCCESS;
  310. }
  311. Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
  312. if (!has_optional_inputs) {
  313. canonical_index = index;
  314. return SUCCESS;
  315. }
  316. auto iter = std::find(input_desc_indices_.begin(), input_desc_indices_.end(), index);
  317. if (iter == input_desc_indices_.end()) {
  318. GELOGE(INTERNAL_ERROR, "[%s] Invalid input index: %u", node_name.c_str(), index);
  319. return INTERNAL_ERROR;
  320. }
  321. canonical_index = static_cast<int>(iter - input_desc_indices_.begin());
  322. GELOGD("[%s] Canonicalize input index from [%u] to [%d]", node_name.c_str(), index, canonical_index);
  323. return SUCCESS;
  324. }
  325. bool NodeItem::IsInputShapeStatic(int index) const {
  326. if (!is_dynamic) {
  327. return true;
  328. }
  329. if (static_cast<size_t>(index) >= is_input_shape_static_.size()) {
  330. GELOGE(PARAM_INVALID, "Input index(%d) out of range: [0, %zu)", index, is_input_shape_static_.size());
  331. return false;
  332. }
  333. return is_input_shape_static_[index];
  334. }
  335. } // namespace hybrid
  336. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：