You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_op_pass_unittest.cc 47 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <memory>
  18. #include <mutex>
  19. #include <thread>
  20. #include <vector>
  21. #include "common/types.h"
  22. #define protected public
  23. #define private public
  24. #include "graph/passes/variable_op_pass.h"
  25. #include "common/op/ge_op_utils.h"
  26. #include "graph/utils/op_desc_utils.h"
  27. #include "graph/utils/attr_utils.h"
  28. #include "graph/utils/graph_utils.h"
  29. #include "graph/op_desc.h"
  30. #include "graph/types.h"
  31. #include "graph/manager/graph_context.h"
  32. #include "graph/optimize/graph_optimize.h"
  33. #include "graph/manager/util/variable_accelerate_ctrl.h"
  34. #include "graph/manager/graph_mem_manager.h"
  35. #include "graph/manager/graph_var_manager.h"
  36. #include "graph_builder_utils.h"
  37. #include "cce/dnn.h"
  38. #include "cce/dnn_struct_base.hpp"
  39. #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h"
  40. #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h"
  41. #include "common/formats/format_transfers/datatype_transfer.h"
  42. #undef private
  43. #undef protected
  44. using namespace std;
  45. using namespace ge;
  46. using namespace cce;
  47. class UtestVariableOpPassUnit : public testing::Test {
  48. protected:
  49. void SetUp() {}
  50. void TearDown() {}
  51. // AUTO GEN PLEASE DO NOT MODIFY IT
  52. };
  53. namespace {
  54. /// c
  55. /// var1ref1 --> netoutput1
  56. /// \ /
  57. /// transdata2
  58. /// |
  59. /// assign1
  60. /// / \.
  61. /// transdata1 |
  62. /// | |
  63. /// var1 const1
  64. ComputeGraphPtr BuildGraph1() {
  65. auto builder = ut::GraphBuilder("g1");
  66. auto var1 = builder.AddNode("var1", "Variable", 0, 1);
  67. auto const1 =
  68. builder.AddNode("const1", "Const", 0, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16}));
  69. auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT,
  70. std::vector<int64_t>({1, 1, 224, 224, 16}));
  71. transdata1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NCHW);
  72. transdata1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 3, 224, 224})));
  73. auto assign1 =
  74. builder.AddNode("assign1", "Assign", 2, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16}));
  75. auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT,
  76. std::vector<int64_t>({1, 1, 224, 224, 16}));
  77. transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NCHW);
  78. transdata2->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 3, 224, 224})));
  79. auto var1ref1 = builder.AddNode("var1ref1", "Variable", 1, 0);
  80. AttrUtils::SetStr(var1ref1->GetOpDesc(), REF_VAR_SRC_VAR_NAME, "var1");
  81. auto netoutput1 = builder.AddNode("netoutput1", "Netoutput", 2, 0);
  82. builder.AddDataEdge(var1, 0, transdata1, 0);
  83. builder.AddDataEdge(const1, 0, assign1, 1);
  84. builder.AddDataEdge(transdata1, 0, assign1, 0);
  85. builder.AddDataEdge(assign1, 0, transdata2, 0);
  86. builder.AddDataEdge(transdata2, 0, var1ref1, 0);
  87. builder.AddDataEdge(transdata2, 0, netoutput1, 0);
  88. builder.AddControlEdge(var1ref1, netoutput1);
  89. return builder.GetGraph();
  90. }
  91. /// conv1
  92. /// |
  93. /// reshape1
  94. /// |
  95. /// var1
  96. ComputeGraphPtr BuildGraph2() {
  97. auto builder = ut::GraphBuilder("g1");
  98. auto var1 = builder.AddNode("var1", "Variable", 0, 1, FORMAT_ND, DT_FLOAT, std::vector<int64_t>({8 * 8 * 3, 2}));
  99. auto reshape1 =
  100. builder.AddNode("reshape1", "Reshape", 2, 1, FORMAT_HWCN, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  101. reshape1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_ND);
  102. reshape1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({8 * 8 * 3, 2})));
  103. auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1, FORMAT_HWCN, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  104. builder.AddDataEdge(var1, 0, reshape1, 0);
  105. builder.AddDataEdge(reshape1, 0, conv1, 1);
  106. return builder.GetGraph();
  107. }
  108. /// conv1
  109. /// |
  110. /// reformat1
  111. /// |
  112. /// var1
  113. ComputeGraphPtr BuildGraph3() {
  114. auto builder = ut::GraphBuilder("g1");
  115. auto var1 = builder.AddNode("var1", "Variable", 0, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  116. auto reformat1 =
  117. builder.AddNode("reformat1", "ReFormat", 1, 1, FORMAT_ND, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  118. reformat1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NCHW);
  119. reformat1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({8, 8, 3, 2})));
  120. auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1, FORMAT_ND, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  121. builder.AddDataEdge(var1, 0, reformat1, 0);
  122. builder.AddDataEdge(reformat1, 0, conv1, 1);
  123. return builder.GetGraph();
  124. }
  125. class NodeBuilder {
  126. public:
  127. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  128. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  129. ge::DataType data_type = DT_FLOAT) {
  130. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  131. return *this;
  132. }
  133. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  134. ge::DataType data_type = DT_FLOAT) {
  135. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  136. return *this;
  137. }
  138. ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
  139. private:
  140. ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  141. ge::DataType data_type = DT_FLOAT) {
  142. GeShape ge_shape{std::vector<int64_t>(shape)};
  143. ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
  144. tensor_desc->SetShape(ge_shape);
  145. tensor_desc->SetFormat(format);
  146. tensor_desc->SetDataType(data_type);
  147. return tensor_desc;
  148. }
  149. ge::OpDescPtr op_desc_;
  150. };
  151. std::string var_ref_name_0;
  152. ge::NodePtr CreatVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node) {
  153. GELOGI("Create VarRef Op: final_writable_node: [%s] var_node: [%s]>>>>", final_writable_node->GetName().c_str(),
  154. var_node->GetName().c_str());
  155. static uint32_t var_ref_count = 0;
  156. std::stringstream var_ref_name;
  157. var_ref_name << "_to_" << final_writable_node->GetName() << "_REF_" << var_ref_count++;
  158. OpDescPtr var_op_desc = var_node->GetOpDesc();
  159. GE_CHK_BOOL_EXEC(var_op_desc != nullptr, return nullptr, "get var opdesc is nullptr");
  160. OpDescPtr var_ref_op_desc = nullptr;
  161. GE_MAKE_SHARED(var_ref_op_desc =
  162. std::make_shared<OpDesc>(var_node->GetName() + var_ref_name.str().c_str(), var_op_desc->GetType()),
  163. return nullptr);
  164. var_ref_op_desc->AddOutputDesc(var_op_desc->GetOutputDesc(0));
  165. var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0));
  166. const map<string, ge::GeAttrValue> var_attr_value = var_op_desc->GetAllAttrs();
  167. for (auto const &attrIt : var_attr_value) {
  168. var_ref_op_desc->SetAttr(attrIt.first, attrIt.second);
  169. }
  170. NodePtr var_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc);
  171. GE_CHK_BOOL_EXEC(var_ref_node != nullptr, return nullptr, "create var_REF_node failed")
  172. GE_IF_BOOL_EXEC(ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName()),
  173. GELOGI("Set node [%s] VAR_ATTR_VAR_IS_REF [%s]", var_ref_node->GetName().c_str(),
  174. var_op_desc->GetName().c_str()));
  175. var_ref_name_0 = var_ref_node->GetName();
  176. return var_ref_node;
  177. }
  178. bool BuildComputeGraph0(ge::ComputeGraphPtr &graph) {
  179. // graph = std::make_shared<ComputeGraph>("test");
  180. ge::NodePtr node_4d_new =
  181. NodeBuilder("Node4D_new", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  182. ge::NodePtr node_4d_to_5d_1_new = NodeBuilder("4d_to_5d_1_new", TRANSDATA)
  183. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  184. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  185. .Build(graph);
  186. ge::NodePtr node_4d_to_5d_2_new = NodeBuilder("4d_to_5d_2_new", TRANSDATA)
  187. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  188. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  189. .Build(graph);
  190. ge::GraphUtils::AddEdge(node_4d_new->GetOutDataAnchor(0), node_4d_to_5d_1_new->GetInDataAnchor(0));
  191. ge::GraphUtils::AddEdge(node_4d_new->GetOutDataAnchor(0), node_4d_to_5d_2_new->GetInDataAnchor(0));
  192. // Node4D
  193. ge::NodePtr node_4d =
  194. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  195. // NodeTrans4DTo5D
  196. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  197. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  198. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  199. .Build(graph);
  200. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  201. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  202. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  203. .Build(graph);
  204. // Node5D
  205. ge::NodePtr node_5d_1 =
  206. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  207. ge::NodePtr node_5d_2 =
  208. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  209. // add edge
  210. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  211. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  212. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  213. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  214. // Node4D
  215. ge::NodePtr node_4d_nhwc =
  216. NodeBuilder("Node4D_NHWC", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NHWC, DT_INT32).Build(graph);
  217. // NodeTrans4DTo5D
  218. ge::NodePtr node_4d_to_5d_1_nhwc = NodeBuilder("4d_to_5d_1_NHWC", TRANSDATA)
  219. .AddInputDesc({1, 2, 3, 4}, FORMAT_NHWC, DT_INT32)
  220. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  221. .Build(graph);
  222. // Node5D
  223. ge::NodePtr node_5d_1_nhwc =
  224. NodeBuilder("5D_1_NHWC", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  225. // add edge
  226. ge::GraphUtils::AddEdge(node_4d_nhwc->GetOutDataAnchor(0), node_4d_to_5d_1_nhwc->GetInDataAnchor(0));
  227. ge::GraphUtils::AddEdge(node_4d_to_5d_1_nhwc->GetOutDataAnchor(0), node_5d_1_nhwc->GetInDataAnchor(0));
  228. // Node4D
  229. ge::NodePtr node_4d_hwcn =
  230. NodeBuilder("Node4D_HWCN", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_HWCN, DT_INT32).Build(graph);
  231. // NodeTrans4DTo5D
  232. ge::NodePtr node_4d_to_5d_1_hwcn = NodeBuilder("4d_to_5d_1_HWCN", TRANSDATA)
  233. .AddInputDesc({1, 2, 3, 4}, FORMAT_HWCN, DT_INT32)
  234. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  235. .Build(graph);
  236. // Node5D
  237. ge::NodePtr node_5d_1_hwcn =
  238. NodeBuilder("5D_1_HWCN", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  239. // add edge
  240. ge::GraphUtils::AddEdge(node_4d_hwcn->GetOutDataAnchor(0), node_4d_to_5d_1_hwcn->GetInDataAnchor(0));
  241. ge::GraphUtils::AddEdge(node_4d_to_5d_1_hwcn->GetOutDataAnchor(0), node_5d_1_hwcn->GetInDataAnchor(0));
  242. ge::NodePtr node_4d_chwn =
  243. NodeBuilder("Node4D_CHWN", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_CHWN, DT_INT32).Build(graph);
  244. // NodeTrans4DTo5D
  245. ge::NodePtr node_4d_to_5d_1_chwn = NodeBuilder("4d_to_5d_1_CHWN", TRANSDATA)
  246. .AddInputDesc({1, 2, 3, 4}, FORMAT_CHWN, DT_INT32)
  247. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  248. .Build(graph);
  249. // Node5D
  250. ge::NodePtr node_5d_1_chwn =
  251. NodeBuilder("5D_1_CHWN", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  252. // add edge
  253. ge::GraphUtils::AddEdge(node_4d_chwn->GetOutDataAnchor(0), node_4d_to_5d_1_chwn->GetInDataAnchor(0));
  254. ge::GraphUtils::AddEdge(node_4d_to_5d_1_chwn->GetOutDataAnchor(0), node_5d_1_chwn->GetInDataAnchor(0));
  255. ge::NodePtr node_4d_d =
  256. NodeBuilder("Node4D_D", VARIABLE).AddOutputDesc({1}, FORMAT_CHWN, DT_INT32).Build(graph);
  257. // NodeTrans4DTo5D
  258. ge::NodePtr node_4d_to_5d_1_d = NodeBuilder("4d_to_5d_1_D", TRANSDATA)
  259. .AddInputDesc({1, 2, 3, 4}, FORMAT_CHWN, DT_INT32)
  260. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  261. .Build(graph);
  262. // Node5D
  263. ge::NodePtr node_5d_1_d =
  264. NodeBuilder("5D_1_D", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  265. ge::NodePtr node_apply_monetum = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  266. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  267. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  268. .Build(graph);
  269. ge::NodePtr node_5d_to_4d_1 = NodeBuilder("5d_to_4d_1", TRANSDATA)
  270. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  271. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  272. .Build(graph);
  273. ge::NodePtr node_ref = CreatVariableRef(node_5d_to_4d_1, node_4d);
  274. // add edge
  275. ge::GraphUtils::AddEdge(node_4d_d->GetOutDataAnchor(0), node_4d_to_5d_1_d->GetInDataAnchor(0));
  276. ge::GraphUtils::AddEdge(node_4d_to_5d_1_d->GetOutDataAnchor(0), node_5d_1_d->GetInDataAnchor(0));
  277. if (ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0)) !=
  278. ge::SUCCESS) {
  279. };
  280. ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  281. return true;
  282. }
  283. bool BuildComputeGraph1(ge::ComputeGraphPtr &graph) {
  284. // Node4D
  285. ge::NodePtr node_4d =
  286. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  287. // NodeTrans4DTo5D
  288. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  289. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  290. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  291. .Build(graph);
  292. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  293. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  294. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  295. .Build(graph);
  296. // Node5D
  297. ge::NodePtr node_5d_1 =
  298. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  299. ge::NodePtr node_5d_2 =
  300. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  301. ge::NodePtr node_5d_to_4d_1 = NodeBuilder("5d_to_4d_1", TRANSDATA)
  302. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  303. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  304. .Build(graph);
  305. ge::NodePtr node_apply_monetum = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  306. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  307. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  308. .Build(graph);
  309. ge::NodePtr node_ref = CreatVariableRef(node_5d_to_4d_1, node_4d);
  310. // add edge
  311. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  312. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  313. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  314. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  315. if (ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0)) !=
  316. ge::SUCCESS) {
  317. };
  318. ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  319. return true;
  320. }
  321. bool BuildComputeGraph4(ge::ComputeGraphPtr &graph) {
  322. // Node4D
  323. ge::NodePtr node_4d =
  324. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  325. // NodeTrans4DTo5D
  326. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  327. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  328. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  329. .Build(graph);
  330. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  331. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  332. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  333. .Build(graph);
  334. // Node5D
  335. ge::NodePtr node_5d_1 =
  336. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  337. ge::NodePtr node_5d_2 =
  338. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  339. ge::NodePtr node_5d_to_4d_1 = NodeBuilder("5d_to_4d_1", TRANSDATA)
  340. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  341. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  342. .Build(graph);
  343. ge::NodePtr node_5d_to_4d_2 = NodeBuilder("5d_to_4d_2", TRANSDATA)
  344. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  345. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  346. .Build(graph);
  347. ge::NodePtr node_apply_monetum = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  348. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  349. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  350. .Build(graph);
  351. ge::NodePtr node_ref = CreatVariableRef(node_5d_to_4d_1, node_4d);
  352. // add edge
  353. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  354. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  355. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  356. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  357. ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0));
  358. ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  359. ge::GraphUtils::AddEdge(node_5d_to_4d_2->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  360. return true;
  361. }
  362. bool BuildComputeGraph5(ge::ComputeGraphPtr &graph) {
  363. // Node4D
  364. ge::NodePtr node_4d =
  365. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  366. return true;
  367. }
  368. bool BuildComputeGraph6(ge::ComputeGraphPtr &graph) {
  369. // Node4D
  370. ge::NodePtr node_4d =
  371. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  372. // NodeTrans4DTo5D
  373. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  374. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  375. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  376. .Build(graph);
  377. ge::NodePtr node_float_to_int_1 = NodeBuilder("float_to_int_1", CAST)
  378. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  379. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  380. .Build(graph);
  381. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  382. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  383. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  384. .Build(graph);
  385. ge::NodePtr node_float_to_int_2 = NodeBuilder("float_to_int_2", CAST)
  386. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  387. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  388. .Build(graph);
  389. // Node5D
  390. ge::NodePtr node_5d_1 =
  391. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32).Build(graph);
  392. ge::NodePtr node_5d_2 =
  393. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32).Build(graph);
  394. // add edge
  395. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  396. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  397. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_float_to_int_1->GetInDataAnchor(0));
  398. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_float_to_int_2->GetInDataAnchor(0));
  399. ge::GraphUtils::AddEdge(node_float_to_int_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  400. ge::GraphUtils::AddEdge(node_float_to_int_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  401. return true;
  402. }
  403. } // namespace
  404. bool BuildComputeGraph7(ge::ComputeGraphPtr &graph) {
  405. // Node4D
  406. ge::NodePtr node_4d =
  407. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  408. // NodeTrans4DTo5D
  409. ge::NodePtr node_4d_to_4d_1 = NodeBuilder("4d_to_4d_1", TRANSDATA)
  410. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  411. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  412. .Build(graph);
  413. // Node5D
  414. ge::NodePtr node_4d_1 = NodeBuilder("4D_1", RELU).AddInputDesc({1, 2, 3, 4}, FORMAT_NC1HWC0, DT_INT32).Build(graph);
  415. // add edge
  416. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_4d_1->GetInDataAnchor(0));
  417. ge::GraphUtils::AddEdge(node_4d_to_4d_1->GetOutDataAnchor(0), node_4d_1->GetInDataAnchor(0));
  418. return true;
  419. }
  420. class VariableOpPassSimulator {
  421. public:
  422. bool DoTest0() {
  423. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  424. const std::string var_name = "Node4D";
  425. uint64_t session_id = 0;
  426. uint32_t device_id = 0;
  427. uint64_t job_id = 0;
  428. uint32_t session_version = 0;
  429. std::vector<int64_t> dims(4, 20);
  430. ge::GeShape shape(dims);
  431. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  432. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  433. BuildComputeGraph0(compute_graph);
  434. std::vector<std::string> var_names = {"Node4D_new", "Node4D", "Node4D_NHWC",
  435. "Node4D_HWCN", "Node4D_CHWN", "Node4D_D"};
  436. for (auto name : var_names) {
  437. auto var_node = compute_graph->FindNode(name);
  438. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  439. uint8_t *dev_ptr = nullptr;
  440. ge::VarManager::Instance(session_id)->AssignVarMem(name, var_tensor_desc, RT_MEMORY_HBM);
  441. ge::VarManager::Instance(session_id)->SetVarAddr(name, var_tensor_desc, dev_ptr, RT_MEMORY_HBM);
  442. }
  443. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  444. compute_graph->InferShapeInNeed();
  445. graph_node->SetComputeGraph(compute_graph);
  446. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  447. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  448. graph_node->SetGraph(tmp_graph_ptr);
  449. VarAccelerateCtrl ctrl;
  450. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  451. ge::formats::FormatTransferNchwNc1hwc0 ClassObj;
  452. VariableOpPass pass(&ctrl);
  453. pass.Run(compute_graph);
  454. MemManager::Instance().Finalize();
  455. return CheckTest0(compute_graph);
  456. }
  457. bool DoTest1() {
  458. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  459. const std::string var_name = "Node4D";
  460. uint64_t session_id = 0;
  461. uint32_t device_id = 0;
  462. uint64_t job_id = 0;
  463. uint32_t session_version = 0;
  464. std::vector<int64_t> dims(4, 20);
  465. ge::GeShape shape(dims);
  466. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  467. BuildComputeGraph1(compute_graph);
  468. auto var_node = compute_graph->FindNode(var_name);
  469. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  470. uint8_t *dev_ptr = nullptr;
  471. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  472. compute_graph->InferShapeInNeed();
  473. graph_node->SetComputeGraph(compute_graph);
  474. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  475. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  476. graph_node->SetGraph(tmp_graph_ptr);
  477. VarAccelerateCtrl ctrl;
  478. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  479. VariableOpPass pass(&ctrl);
  480. pass.Run(compute_graph);
  481. return CheckTest1(compute_graph);
  482. }
  483. bool DoTest2() {
  484. VarAccelerateCtrl ctrl;
  485. VariableOpPass pass(&ctrl);
  486. return pass.Run(nullptr) == ge::INTERNAL_ERROR;
  487. }
  488. bool DoTest3() {
  489. std::vector<rtMemType_t> mem_type;
  490. std::map<std::string, std::string> empty_options;
  491. mem_type.push_back(RT_MEMORY_HBM);
  492. MemManager::Instance().Initialize(mem_type);
  493. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  494. std::vector<std::string> var_names = {"Node4D", "Node4D_NHWC", "Node4D_HWCN", "Node4D_CHWN", "Node4D_D"};
  495. std::vector<ge::GeTensorDesc> tensor_descs;
  496. uint64_t session_id = 0;
  497. uint32_t device_id = 0;
  498. uint64_t job_id = 0;
  499. uint32_t session_version = 0;
  500. compute_graph->SetSessionID(session_id);
  501. std::vector<int64_t> dims(4, 20);
  502. ge::GeShape shape(dims);
  503. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  504. BuildComputeGraph0(compute_graph);
  505. for (auto var_name : var_names) {
  506. auto var_node = compute_graph->FindNode(var_name);
  507. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  508. uint8_t *dev_ptr = nullptr;
  509. ge::VarManager::Instance(session_id)->AssignVarMem(var_name, var_tensor_desc, RT_MEMORY_HBM);
  510. ge::VarManager::Instance(session_id)->SetVarAddr(var_name, var_tensor_desc, dev_ptr, RT_MEMORY_HBM);
  511. }
  512. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  513. compute_graph->InferShapeInNeed();
  514. graph_node->SetComputeGraph(compute_graph);
  515. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  516. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  517. graph_node->SetGraph(tmp_graph_ptr);
  518. VarAccelerateCtrl ctrl;
  519. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  520. VariableOpPass pass(&ctrl);
  521. auto ret = pass.Run(compute_graph);
  522. MemManager::Instance().Finalize();
  523. return ret == GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  524. }
  525. bool DoTest4() {
  526. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  527. const std::string var_name = "Node4D";
  528. uint64_t session_id = 0;
  529. uint32_t device_id = 0;
  530. uint64_t job_id = 0;
  531. uint32_t session_version = 0;
  532. std::vector<int64_t> dims(4, 20);
  533. ge::GeShape shape(dims);
  534. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  535. BuildComputeGraph4(compute_graph);
  536. auto var_node = compute_graph->FindNode(var_name);
  537. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  538. uint8_t *dev_ptr = nullptr;
  539. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  540. compute_graph->InferShapeInNeed();
  541. graph_node->SetComputeGraph(compute_graph);
  542. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  543. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  544. graph_node->SetGraph(tmp_graph_ptr);
  545. VarAccelerateCtrl ctrl;
  546. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  547. VariableOpPass pass(&ctrl);
  548. auto ret = pass.Run(compute_graph);
  549. return ret == ge::SUCCESS;
  550. }
  551. bool DoTest5() {
  552. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  553. BuildComputeGraph5(compute_graph);
  554. const std::string var_name = "Node4D";
  555. uint64_t session_id = 0;
  556. uint32_t device_id = 0;
  557. uint64_t job_id = 0;
  558. uint32_t session_version = 0;
  559. std::vector<int64_t> dims(4, 20);
  560. ge::GeShape shape(dims);
  561. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  562. BuildComputeGraph4(compute_graph);
  563. auto var_node = compute_graph->FindNode(var_name);
  564. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  565. uint8_t *dev_ptr = nullptr;
  566. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  567. compute_graph->InferShapeInNeed();
  568. graph_node->SetComputeGraph(compute_graph);
  569. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  570. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  571. graph_node->SetGraph(tmp_graph_ptr);
  572. VarAccelerateCtrl ctrl;
  573. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  574. VariableOpPass pass(&ctrl);
  575. auto ret = pass.Run(compute_graph);
  576. return ret == ge::SUCCESS;
  577. }
  578. bool DoTest6() {
  579. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  580. const std::string var_name = "Node4D";
  581. uint64_t session_id = 0;
  582. uint32_t device_id = 0;
  583. uint64_t job_id = 0;
  584. uint32_t session_version = 0;
  585. std::vector<int64_t> dims(4, 20);
  586. ge::GeShape shape(dims);
  587. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  588. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  589. BuildComputeGraph6(compute_graph);
  590. auto var_node = compute_graph->FindNode(var_name);
  591. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  592. uint8_t *dev_ptr = nullptr;
  593. ge::VarManager::Instance(session_id)->AssignVarMem(var_name, var_tensor_desc, RT_MEMORY_HBM);
  594. ge::VarManager::Instance(session_id)->SetVarAddr(var_name, var_tensor_desc, dev_ptr, RT_MEMORY_HBM);
  595. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  596. compute_graph->InferShapeInNeed();
  597. graph_node->SetComputeGraph(compute_graph);
  598. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  599. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  600. graph_node->SetGraph(tmp_graph_ptr);
  601. VarAccelerateCtrl ctrl;
  602. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  603. ge::formats::FormatTransferNchwNc1hwc0 ClassObj;
  604. VariableOpPass pass(&ctrl);
  605. auto ret = pass.Run(compute_graph);
  606. MemManager::Instance().Finalize();
  607. return CheckTest6(compute_graph);
  608. }
  609. bool DoTest7() {
  610. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  611. const std::string var_name = "Node4D";
  612. uint64_t session_id = 0;
  613. uint32_t device_id = 0;
  614. uint64_t job_id = 0;
  615. uint32_t session_version = 0;
  616. std::vector<int64_t> dims(4, 20);
  617. ge::GeShape shape(dims);
  618. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  619. BuildComputeGraph7(compute_graph);
  620. auto var_node = compute_graph->FindNode(var_name);
  621. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  622. uint8_t *dev_ptr = nullptr;
  623. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  624. compute_graph->InferShapeInNeed();
  625. graph_node->SetComputeGraph(compute_graph);
  626. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  627. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  628. graph_node->SetGraph(tmp_graph_ptr);
  629. VarAccelerateCtrl ctrl;
  630. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  631. VariableOpPass pass(&ctrl);
  632. auto ret = pass.Run(compute_graph);
  633. return CheckTest7(compute_graph);
  634. }
  635. bool DoTest8() {
  636. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  637. const std::string var_name = "Node4D";
  638. uint64_t session_id = 0;
  639. uint32_t device_id = 0;
  640. uint64_t job_id = 0;
  641. uint32_t session_version = 0;
  642. std::vector<int64_t> dims(4, 20);
  643. ge::GeShape shape(dims);
  644. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  645. BuildComputeGraph0(compute_graph);
  646. auto var_node = compute_graph->FindNode(var_name);
  647. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  648. uint8_t *dev_ptr = nullptr;
  649. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  650. compute_graph->InferShapeInNeed();
  651. graph_node->SetComputeGraph(compute_graph);
  652. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  653. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  654. graph_node->SetGraph(tmp_graph_ptr);
  655. VarAccelerateCtrl ctrl;
  656. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  657. VariableOpPass pass(&ctrl);
  658. pass.Run(compute_graph);
  659. return CheckTest8(compute_graph);
  660. }
  661. private:
  662. bool CheckTest0(const ge::ComputeGraphPtr compute_graph) {
  663. const auto &variable_node = compute_graph->FindNode("Node4D");
  664. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  665. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  666. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  667. if (variable_node_format != FORMAT_NC1HWC0 || variable_node_data_type != DT_FLOAT ||
  668. variable_node_shape.size() != 5) {
  669. std::cout << "var format not changed !" << std::endl;
  670. return false;
  671. }
  672. const auto &variable_ref_node = compute_graph->FindNode(var_ref_name_0);
  673. GELOGD("var_ref_name_0 is %s", var_ref_name_0.c_str());
  674. auto variable_ref_node_format = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetFormat();
  675. auto variable_ref_node_data_type = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetDataType();
  676. auto variable_ref_node_shape = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
  677. if (variable_ref_node_format != FORMAT_NC1HWC0 || variable_ref_node_data_type != DT_FLOAT ||
  678. variable_ref_node_shape.size() != 5) {
  679. GELOGI("wanted data format is (%d,%d,%u)", FORMAT_NC1HWC0, DT_FLOAT, 5);
  680. GELOGI("variable_ref_node_format is (%d,%d,%zu)", variable_ref_node_format, variable_ref_node_data_type,
  681. variable_ref_node_shape.size());
  682. std::cout << "var ref format not changed !" << std::endl;
  683. return false;
  684. }
  685. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_5d_1");
  686. if (trans_node != nullptr) {
  687. std::cout << "4d_to_5d_1 not empty !" << std::endl;
  688. return false;
  689. }
  690. trans_node = compute_graph->FindNode("4d_to_5d_2");
  691. if (trans_node != nullptr) {
  692. std::cout << "4d_to_5d_2 not empty !" << std::endl;
  693. return false;
  694. }
  695. trans_node = compute_graph->FindNode("5d_to_4d_1");
  696. if (trans_node != nullptr) {
  697. std::cout << "5d_to_4d_1 not empty !" << std::endl;
  698. return false;
  699. }
  700. trans_node = compute_graph->FindNode("4d_to_5d_1_new");
  701. if (trans_node == nullptr) {
  702. std::cout << "4d_to_5d_1_new is empty !" << std::endl;
  703. return false;
  704. }
  705. auto new_variable_node = compute_graph->FindNode("Node4D_new");
  706. auto new_variable_node_format = new_variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  707. auto new_variable_node_data_type = new_variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  708. auto new_variable_node_shape = new_variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  709. if (new_variable_node_format != FORMAT_NCHW || new_variable_node_data_type != DT_INT32 ||
  710. new_variable_node_shape.size() != 4) {
  711. std::cout << "Node4D_new format Changed ! wanted data format is ( " << FORMAT_NC1HWC0 << ", " << DT_INT32
  712. << ", 4) " << std::endl;
  713. std::cout << "current is ( " << new_variable_node_format << ", " << new_variable_node_data_type << ", "
  714. << new_variable_node_shape.size() << ")" << std::endl;
  715. return false;
  716. }
  717. return true;
  718. };
  719. bool CheckTest1(const ge::ComputeGraphPtr compute_graph) {
  720. const auto &variable_node = compute_graph->FindNode("Node4D");
  721. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  722. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  723. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  724. if (variable_node_format != FORMAT_NCHW || variable_node_data_type != DT_INT32 || variable_node_shape.size() != 4) {
  725. std::cout << "var format changed !" << std::endl;
  726. return false;
  727. }
  728. const auto &variable_ref_node = compute_graph->FindNode(var_ref_name_0);
  729. GELOGD("var_ref_name_0 is %s", var_ref_name_0.c_str());
  730. auto variable_ref_node_format = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetFormat();
  731. auto variable_ref_node_data_type = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetDataType();
  732. auto variable_ref_node_shape = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
  733. if (variable_ref_node_format != FORMAT_NCHW || variable_ref_node_data_type != DT_INT32 ||
  734. variable_ref_node_shape.size() != 4) {
  735. GELOGI("wanted data format is (%d,%d,%u)", FORMAT_NCHW, DT_INT32, 4);
  736. GELOGI("variable_ref_node_format is (%d,%d,%zu)", variable_ref_node_format, variable_ref_node_data_type,
  737. variable_ref_node_shape.size());
  738. std::cout << "var ref format not changed !" << std::endl;
  739. return false;
  740. }
  741. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_5d_1");
  742. if (trans_node == nullptr) {
  743. std::cout << "4d_to_5d_1 empty !" << std::endl;
  744. return false;
  745. }
  746. trans_node = compute_graph->FindNode("4d_to_5d_2");
  747. if (trans_node == nullptr) {
  748. std::cout << "4d_to_5d_2 empty !" << std::endl;
  749. return false;
  750. }
  751. trans_node = compute_graph->FindNode("5d_to_4d_1");
  752. if (trans_node == nullptr) {
  753. std::cout << "5d_to_4d_1 not empty !" << std::endl;
  754. return false;
  755. }
  756. return true;
  757. };
  758. bool CheckTest6(const ge::ComputeGraphPtr compute_graph) {
  759. const auto &variable_node = compute_graph->FindNode("Node4D");
  760. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  761. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  762. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  763. if (variable_node_format != FORMAT_NC1HWC0 || variable_node_data_type != DT_INT32 ||
  764. variable_node_shape.size() != 5) {
  765. std::cout << "var format not changed !" << std::endl;
  766. return false;
  767. }
  768. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_5d_1");
  769. if (trans_node != nullptr) {
  770. std::cout << "4d_to_5d_1 not empty !" << std::endl;
  771. return false;
  772. }
  773. trans_node = compute_graph->FindNode("4d_to_5d_2");
  774. if (trans_node != nullptr) {
  775. std::cout << "4d_to_5d_2 not empty !" << std::endl;
  776. return false;
  777. }
  778. trans_node = compute_graph->FindNode("float_to_int_1");
  779. if (trans_node != nullptr) {
  780. std::cout << "float_to_int_1 not empty !" << std::endl;
  781. return false;
  782. }
  783. trans_node = compute_graph->FindNode("float_to_int_2");
  784. if (trans_node != nullptr) {
  785. std::cout << "float_to_int_1 not empty !" << std::endl;
  786. return false;
  787. }
  788. return true;
  789. };
  790. bool CheckTest7(const ge::ComputeGraphPtr compute_graph) {
  791. const auto &variable_node = compute_graph->FindNode("Node4D");
  792. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  793. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  794. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  795. if (variable_node_format != FORMAT_NC1HWC0 || variable_node_data_type != DT_INT32 ||
  796. variable_node_shape.size() != 5) {
  797. std::cout << "var format not changed !" << std::endl;
  798. return false;
  799. }
  800. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_4d_1");
  801. if (trans_node != nullptr) {
  802. std::cout << "4d_to_5d_1 not empty !" << std::endl;
  803. return false;
  804. }
  805. return true;
  806. };
  807. bool CheckTest8(const ge::ComputeGraphPtr compute_graph) {
  808. const auto &variable_node = compute_graph->FindNode("Node4D");
  809. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  810. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  811. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  812. return true;
  813. };
  814. };
  815. TEST_F(UtestVariableOpPassUnit, test_trans_data_remove) {
  816. VariableOpPassSimulator varibale_op_pass_simulator;
  817. bool result = varibale_op_pass_simulator.DoTest0();
  818. EXPECT_EQ(result, true);
  819. }
  820. TEST_F(UtestVariableOpPassUnit, test_variable_ref) {
  821. VariableOpPassSimulator varibale_op_pass_simulator;
  822. bool result = varibale_op_pass_simulator.DoTest1();
  823. EXPECT_EQ(result, true);
  824. }
  825. TEST_F(UtestVariableOpPassUnit, test_null_graph) {
  826. VariableOpPassSimulator varibale_op_pass_simulator;
  827. bool result = varibale_op_pass_simulator.DoTest2();
  828. EXPECT_EQ(result, true);
  829. }
  830. TEST_F(UtestVariableOpPassUnit, test_covarage_trans_var_data) {
  831. VariableOpPassSimulator varibale_op_pass_simulator;
  832. bool result = varibale_op_pass_simulator.DoTest3();
  833. EXPECT_EQ(result, false);
  834. }
  835. TEST_F(UtestVariableOpPassUnit, test_illegally_ref) {
  836. VariableOpPassSimulator varibale_op_pass_simulator;
  837. bool result = varibale_op_pass_simulator.DoTest4();
  838. EXPECT_EQ(result, true);
  839. }
  840. TEST_F(UtestVariableOpPassUnit, test_single_node) {
  841. VariableOpPassSimulator varibale_op_pass_simulator;
  842. bool result = varibale_op_pass_simulator.DoTest5();
  843. EXPECT_EQ(result, true);
  844. }
  845. TEST_F(UtestVariableOpPassUnit, test_un_mathed) {
  846. VariableOpPassSimulator varibale_op_pass_simulator;
  847. bool result = varibale_op_pass_simulator.DoTest6();
  848. EXPECT_EQ(result, true);
  849. }
  850. TEST_F(UtestVariableOpPassUnit, test_same_op) {
  851. VariableOpPassSimulator varibale_op_pass_simulator;
  852. bool result = varibale_op_pass_simulator.DoTest7();
  853. EXPECT_EQ(true, true);
  854. }
  855. TEST_F(UtestVariableOpPassUnit, test_error_return) {
  856. VariableOpPassSimulator varibale_op_pass_simulator;
  857. bool result = varibale_op_pass_simulator.DoTest8();
  858. EXPECT_EQ(true, true);
  859. }
  860. TEST_F(UtestVariableOpPassUnit, reshape) {
  861. // init
  862. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  863. VarManager::Instance(0)->Init(0, 0, 0, 0);
  864. auto graph = BuildGraph2();
  865. graph->SetSessionID(0);
  866. auto var1 = graph->FindNode("var1");
  867. VarManager::Instance(0)->AssignVarMem(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), RT_MEMORY_HBM);
  868. uint8_t *dev_ptr = nullptr;
  869. VarManager::Instance(0)->SetVarAddr(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), dev_ptr, RT_MEMORY_HBM);
  870. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  871. graph->InferShapeInNeed();
  872. graph_node->SetComputeGraph(graph);
  873. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  874. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  875. graph_node->SetGraph(tmp_graph_ptr);
  876. VarAccelerateCtrl ctrl;
  877. ctrl.AddGraph(graph_node->GetGraphId(), graph);
  878. VariableOpPass pass(&ctrl);
  879. EXPECT_EQ(pass.Run(graph), ge::SUCCESS);
  880. MemManager::Instance().Finalize();
  881. EXPECT_EQ(var1->GetOutNodes().size(), 1);
  882. EXPECT_EQ(var1->GetOutDataNodes().at(0)->GetName(), "conv1");
  883. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetFormat(), FORMAT_HWCN);
  884. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims(), std::vector<int64_t>({8, 8, 3, 2}));
  885. }
  886. TEST_F(UtestVariableOpPassUnit, reformat) {
  887. // init
  888. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  889. VarManager::Instance(0)->Init(0, 0, 0, 0);
  890. auto graph = BuildGraph3();
  891. graph->SetSessionID(0);
  892. auto var1 = graph->FindNode("var1");
  893. VarManager::Instance(0)->AssignVarMem(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), RT_MEMORY_HBM);
  894. uint8_t *dev_ptr = nullptr;
  895. VarManager::Instance(0)->SetVarAddr(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), dev_ptr, RT_MEMORY_HBM);
  896. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  897. graph->InferShapeInNeed();
  898. graph_node->SetComputeGraph(graph);
  899. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  900. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  901. graph_node->SetGraph(tmp_graph_ptr);
  902. VarAccelerateCtrl ctrl;
  903. ctrl.AddGraph(graph_node->GetGraphId(), graph);
  904. VariableOpPass pass(&ctrl);
  905. EXPECT_EQ(pass.Run(graph), ge::SUCCESS);
  906. MemManager::Instance().Finalize();
  907. EXPECT_EQ(var1->GetOutNodes().size(), 1);
  908. EXPECT_EQ(var1->GetOutDataNodes().at(0)->GetName(), "conv1");
  909. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetFormat(), FORMAT_ND);
  910. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims(), std::vector<int64_t>({8, 8, 3, 2}));
  911. }
  912. TEST_F(UtestVariableOpPassUnit, invalid_src_shape2) {
  913. formats::FormatTransferNchwNc1hwc0 t1;
  914. formats::FormatTransferNhwcNc1hwc0 t2;
  915. formats::TransArgs args = formats::TransArgs();
  916. formats::TransResult ret;
  917. t2.TransFormat(args, ret);
  918. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练或推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：