You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_ir_build_unittest.cc 15 kB

4 years ago
4 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <stdio.h>
  17. #include <gtest/gtest.h>
  18. #include "ir_build/option_utils.h"
  19. #include "graph/testcase/ge_graph/graph_builder_utils.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "ge/ge_ir_build.h"
  23. #include "graph/ops_stub.h"
  24. #include "ge/ir_build/attr_options/attr_options.h"
  25. #define protected public
  26. #define private public
  27. #undef private
  28. #undef protected
  29. const string DATA = "Data";
  30. const string AddNYes = "AddNYes";
  31. const string NETOUTPUT = "NetOutput";
  32. using namespace ge;
  33. class UtestIrCommon : public testing::Test {
  34. protected:
  35. void SetUp() {}
  36. void TearDown() {}
  37. };
  38. class UtestIrBuild : public testing::Test {
  39. protected:
  40. void SetUp() {}
  41. void TearDown() {}
  42. };
  43. static ge::OpDescPtr CreateOpDesc(const std::string &name, const std::string &type) {
  44. OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, type);
  45. ge::GeTensorDesc ge_tensor_desc;
  46. op_desc->AddInputDesc("input", ge_tensor_desc);
  47. op_desc->AddOutputDesc("output", ge_tensor_desc);
  48. return op_desc;
  49. }
  50. static ComputeGraphPtr BuildComputeGraph() {
  51. auto builder = ut::GraphBuilder("test");
  52. auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
  53. auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
  54. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  55. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  56. builder.AddDataEdge(data1, 0, addn1, 0);
  57. builder.AddDataEdge(data2, 0, addn1, 1);
  58. builder.AddDataEdge(addn1, 0,netoutput, 0);
  59. return builder.GetGraph();
  60. }
  61. static ComputeGraphPtr BuildComputeGraph1() {
  62. auto builder = ut::GraphBuilder("test");
  63. auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3});
  64. auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10});
  65. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  66. auto node1 = builder.AddNode("addd", "Mul", 2, 1);
  67. auto node2 = builder.AddNode("ffm", "FrameworkOp", 2, 1);
  68. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  69. builder.AddDataEdge(data1, 0, addn1, 0);
  70. builder.AddDataEdge(data2, 0, addn1, 1);
  71. builder.AddDataEdge(addn1, 0,netoutput, 0);
  72. return builder.GetGraph();
  73. }
  74. // data not set attr index;
  75. // but becasue of op proto, register attr index. so all data index is zero;
  76. static Graph BuildIrGraph() {
  77. auto data1 = op::Data("data1");
  78. auto data2 = op::Data("data2");
  79. auto data3 = op::Data("data3");
  80. std::vector<Operator> inputs {data1, data2, data3};
  81. std::vector<Operator> outputs;
  82. Graph graph("test_graph");
  83. graph.SetInputs(inputs).SetOutputs(outputs);
  84. return graph;
  85. }
  86. // data set attr index, but is not valid
  87. static Graph BuildIrGraph1() {
  88. auto data1 = op::Data("data1").set_attr_index(0);
  89. auto data2 = op::Data("data2").set_attr_index(1);
  90. auto data3 = op::Data("data3");
  91. auto data4 = op::Data("Test");
  92. std::vector<Operator> inputs {data1, data2, data3, data4};
  93. std::vector<Operator> outputs;
  94. Graph graph("test_graph");
  95. graph.AddNodeByOp(Operator("gg", "Mul"));
  96. graph.SetInputs(inputs).SetOutputs(outputs);
  97. return graph;
  98. }
  99. // data set attr index, but is not valid
  100. static Graph BuildIrGraph2() {
  101. auto data1 = op::Data("data1").set_attr_index(0);
  102. auto data2 = op::Data("data2");
  103. auto data3 = op::Data("data3").set_attr_index(2);
  104. std::vector<Operator> inputs {data1, data2, data3};
  105. std::vector<Operator> outputs;
  106. Graph graph("test_graph");
  107. graph.SetInputs(inputs).SetOutputs(outputs);
  108. return graph;
  109. }
  110. // data set attr index
  111. static Graph BuildIrGraph3() {
  112. auto data1 = op::Data("data1").set_attr_index(0);
  113. auto data2 = op::Data("data2").set_attr_index(1);
  114. auto data3 = op::Data("data3").set_attr_index(2);
  115. std::vector<Operator> inputs {data1, data2, data3};
  116. std::vector<Operator> outputs;
  117. Graph graph("test_graph");
  118. graph.SetInputs(inputs).SetOutputs(outputs);
  119. return graph;
  120. }
  121. TEST(UtestIrCommon, update_data_op_shape) {
  122. ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
  123. map<string, vector<int64_t>> shape_map;
  124. shape_map["Data"] = {{1,2}};
  125. Status ret = UpdateDataOpShape(op_desc, shape_map);
  126. EXPECT_EQ(ret, ge::SUCCESS);
  127. }
  128. TEST(UtestIrCommon, update_data_op_shape_range) {
  129. ge::OpDescPtr op_desc = CreateOpDesc("Data", "Data");
  130. std::vector<std::vector<std::pair<int64_t, int64_t>>> index_shape_range_map;
  131. std::pair<int64_t, int64_t> range_pair(1, 2);
  132. vector<pair<int64_t, int64_t>> range_pair_tmp = { range_pair };
  133. index_shape_range_map.push_back(range_pair_tmp);
  134. AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, 0);
  135. Status ret = UpdateDataOpShapeRange(op_desc, index_shape_range_map);
  136. EXPECT_EQ(ret, ge::SUCCESS);
  137. }
  138. TEST(UtestIrCommon, update_dynamic_shape_range_success) {
  139. ComputeGraphPtr graph = BuildComputeGraph();
  140. std::string input_shape_range = "input1:[1, 2~3, -1];input2:[3~5, 10]";
  141. Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  142. EXPECT_EQ(ret, ge::SUCCESS);
  143. }
  144. TEST(UtestIrCommon, update_dynamic_shape_range_failed) {
  145. ComputeGraphPtr graph = BuildComputeGraph();
  146. // 1
  147. std::string input_shape_range = "input1;[1, 2~3, -1]";
  148. Status ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  149. EXPECT_EQ(ret, ge::PARAM_INVALID);
  150. // 2
  151. input_shape_range = "input1:[1, 2~3, -1)";
  152. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  153. EXPECT_EQ(ret, ge::PARAM_INVALID);
  154. //3
  155. input_shape_range = "input1:[1, 3~2, -1];input2:[3~5, 10]";
  156. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  157. EXPECT_EQ(ret, ge::FAILED);
  158. //4
  159. input_shape_range = "input1:[1, 2~-3, -1]";
  160. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  161. EXPECT_EQ(ret, ge::PARAM_INVALID);
  162. //5
  163. input_shape_range = "input:[1, 2~3, -1]";
  164. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  165. EXPECT_EQ(ret, ge::PARAM_INVALID);
  166. //6
  167. input_shape_range = "addn1:[1, 2~3, -1]";
  168. ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
  169. EXPECT_EQ(ret, ge::PARAM_INVALID);
  170. }
  171. TEST(UtestIrCommon, check_dynamic_image_size_fail) {
  172. map<string, vector<int64_t>> shape_map;
  173. shape_map["input1"] = {8, 3, -1, -1};
  174. string input_format = "NCHW";
  175. string dynamic_image_size = "@64,64;128,128;";
  176. bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
  177. EXPECT_EQ(ret, false);
  178. }
  179. TEST(UtestIrCommon, check_input_format_failed) {
  180. std::string format = "invalid";
  181. Status ret = CheckInputFormat(format);
  182. EXPECT_EQ(ret, ge::PARAM_INVALID);
  183. }
  184. TEST(UtestIrCommon, check_dynamic_batch_size_input_shape_succ) {
  185. map<string, vector<int64_t>> shape_map;
  186. shape_map.insert(std::pair<string, vector<int64_t>>("data", {-1, 2, 3}));
  187. std::string dynamic_batch_size = "11";
  188. bool ret = CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size);
  189. EXPECT_EQ(ret, true);
  190. }
  191. TEST(UtestIrCommon, check_dynamic_images_size_input_shape_succ) {
  192. map<string, vector<int64_t>> shape_map;
  193. shape_map.insert(std::pair<string, vector<int64_t>>("data", {4, -1, -1, 5}));
  194. std::string input_format = "NCHW";
  195. std::string dynamic_image_size = "4,5";
  196. Status ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
  197. EXPECT_EQ(ret, ge::SUCCESS);
  198. }
  199. TEST(UtestIrCommon, check_dynamic_input_param_succ) {
  200. string dynamic_batch_size = "1";
  201. string dynamic_image_size;
  202. string dynamic_dims;
  203. string input_shape = "data:-1,3,244,244";
  204. string input_shape_range;
  205. string input_format = "NCHW";
  206. bool is_dynamic_input = false;
  207. Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims,
  208. input_shape, input_shape_range, input_format,is_dynamic_input);
  209. EXPECT_EQ(ret, ge::SUCCESS);
  210. }
  211. TEST(UtestIrCommon, check_dynamic_input_param_failed) {
  212. string dynamic_batch_size = "1";
  213. string dynamic_image_size;
  214. string dynamic_dims;
  215. string input_shape = "data:1,3,244,244";
  216. string input_shape_range;
  217. string input_format = "NCHW";
  218. bool is_dynamic_input = false;
  219. Status ret = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims,
  220. input_shape, input_shape_range, input_format,is_dynamic_input);
  221. EXPECT_EQ(ret, ge::PARAM_INVALID);
  222. }
  223. TEST(UtestIrCommon, check_modify_mixlist_param) {
  224. std::string precision_mode = "allow_mix_precision";
  225. std::string modify_mixlist = "/mixlist.json";
  226. Status ret = CheckModifyMixlistParamValid(precision_mode, modify_mixlist);
  227. EXPECT_EQ(ret, ge::SUCCESS);
  228. precision_mode = "";
  229. ret = CheckModifyMixlistParamValid(precision_mode, modify_mixlist);
  230. EXPECT_EQ(ret, ge::PARAM_INVALID);
  231. }
  232. TEST(UtestIrCommon, check_compress_weight) {
  233. std::string enable_compress_weight = "true";
  234. std::string compress_weight_conf="./";
  235. Status ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf);
  236. EXPECT_EQ(ret, PARAM_INVALID);
  237. enable_compress_weight = "yes";
  238. compress_weight_conf = "./";
  239. ret = CheckCompressWeightParamValid(enable_compress_weight, compress_weight_conf);
  240. EXPECT_EQ(ret, PARAM_INVALID);
  241. }
  242. TEST(UtestIrCommon, check_param_failed) {
  243. std::string param_invalid = "invalid";
  244. Status ret = CheckOutputTypeParamValid(param_invalid);
  245. EXPECT_EQ(ret, PARAM_INVALID);
  246. ret = CheckBufferOptimizeParamValid(param_invalid);
  247. EXPECT_EQ(ret, PARAM_INVALID);
  248. ret = CheckKeepTypeParamValid(param_invalid);
  249. EXPECT_EQ(ret, PARAM_INVALID);
  250. ret = CheckInsertOpConfParamValid(param_invalid);
  251. EXPECT_EQ(ret, PARAM_INVALID);
  252. ret = CheckDisableReuseMemoryParamValid(param_invalid);
  253. EXPECT_EQ(ret, PARAM_INVALID);
  254. ret = CheckEnableSingleStreamParamValid(param_invalid);
  255. EXPECT_EQ(ret, PARAM_INVALID);
  256. std::string optypelist_for_implmode;
  257. std::string op_select_implmode = "1";
  258. ret = CheckImplmodeParamValid(optypelist_for_implmode, op_select_implmode);
  259. EXPECT_EQ(ret, PARAM_INVALID);
  260. ret = CheckLogParamValidAndSetLogLevel(param_invalid);
  261. }
  262. // Get attr index failed, when set input shape range
  263. TEST(UtestIrBuild, check_data_op_attr_index_invalid_0) {
  264. ComputeGraphPtr compute_graph = BuildComputeGraph();
  265. Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  266. const map<string, string> build_options = {
  267. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  268. };
  269. ModelBufferData model;
  270. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  271. EXPECT_EQ(ret, GRAPH_FAILED);
  272. }
  273. // not set attr index, when set input shape range
  274. TEST(UtestIrBuild, check_data_op_attr_index_invalid_1) {
  275. Graph graph = BuildIrGraph();
  276. const map<string, string> build_options = {
  277. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  278. };
  279. ModelBufferData model;
  280. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  281. EXPECT_EQ(ret, GRAPH_FAILED);
  282. }
  283. // set attr index, but not valid, when set input shape range
  284. TEST(UtestIrBuild, check_data_op_attr_index_invalid_2) {
  285. Graph graph = BuildIrGraph1();
  286. const map<string, string> build_options = {
  287. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  288. };
  289. ModelBufferData model;
  290. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  291. EXPECT_EQ(ret, GRAPH_FAILED);
  292. Graph graph2 = BuildIrGraph2();
  293. ret = aclgrphBuildModel(graph2, build_options, model);
  294. EXPECT_EQ(ret, GRAPH_FAILED);
  295. }
  296. // set attr index valid, when set input shape range
  297. // only check data op attr index valid func.
  298. TEST(UtestIrBuild, check_data_op_attr_index_valid) {
  299. Graph graph = BuildIrGraph3();
  300. const map<string, string> build_options = {
  301. {"input_shape_range", "[1, 2~3, -1],[4~5, 3~5, 10],[1, 2~3, -1]"}
  302. };
  303. ModelBufferData model;
  304. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  305. EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
  306. }
  307. // set attr index invalid, when not set input shape range
  308. // only check data op attr index valid func.
  309. TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) {
  310. Graph graph = BuildIrGraph1();
  311. const map<string, string> build_options;
  312. ModelBufferData model;
  313. graphStatus ret = aclgrphBuildModel(graph, build_options, model);
  314. EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED);
  315. }
  316. TEST(UtestIrBuild, check_modify_mixlist_param) {
  317. Graph graph = BuildIrGraph1();
  318. const std::map<std::string, std::string> build_options = {
  319. {"ge.exec.modify_mixlist", "/modify.json"}
  320. };
  321. ModelBufferData model;
  322. auto ret = aclgrphBuildModel(graph, build_options, model);
  323. EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
  324. }
  325. TEST(UtestIrBuild, check_op_precision_mode_param) {
  326. Graph graph = BuildIrGraph1();
  327. const std::map<std::string, std::string> build_options = {
  328. {"ge.exec.op_precision_mode", "./op_precision_mode.ini"}
  329. };
  330. ModelBufferData model;
  331. auto ret = aclgrphBuildModel(graph, build_options, model);
  332. EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
  333. }
  334. TEST(UtestIrBuild, check_build_model_and_build_step) {
  335. Graph graph_1 = BuildIrGraph1();
  336. const std::map<std::string, std::string> build_options_1 = {
  337. {"ge.buildMode", "xxx"}
  338. };
  339. ModelBufferData model_1;
  340. auto ret_1 = aclgrphBuildModel(graph_1, build_options_1, model_1);
  341. EXPECT_NE(ret_1, GRAPH_SUCCESS);
  342. Graph graph_2 = BuildIrGraph1();
  343. const std::map<std::string, std::string> build_options_2 = {
  344. {"ge.buildStep", "xxx"}
  345. };
  346. ModelBufferData model_2;
  347. auto ret_2 = aclgrphBuildModel(graph_2, build_options_2, model_2);
  348. EXPECT_NE(ret_2, GRAPH_SUCCESS);
  349. Graph graph_3 = BuildIrGraph1();
  350. const std::map<std::string, std::string> build_options_3 = {
  351. {"ge.buildMode", "tuning"}
  352. };
  353. ModelBufferData model_3;
  354. auto ret_3 = aclgrphBuildModel(graph_3, build_options_3, model_3);
  355. EXPECT_NE(ret_3, GRAPH_SUCCESS);
  356. }
  357. TEST(UtestIrBuild, atc_cfg_optype_param) {
  358. ComputeGraphPtr graph = BuildComputeGraph1();
  359. FILE *fp = fopen("./keep.txt", "w+");
  360. if (fp) {
  361. fprintf(fp, "Test\n");
  362. fprintf(fp, "OpType::Mul\n");
  363. fprintf(fp, "Optype::Sub\n");
  364. fclose(fp);
  365. }
  366. auto ret = KeepDtypeFunc(graph, "./keep.txt");
  367. (void)remove("./keep.txt");
  368. EXPECT_EQ(ret, GRAPH_PARAM_INVALID);
  369. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。