
test_network.py

# -*- coding: utf-8 -*-
import os
import unittest

import numpy as np

from megenginelite import *

set_log_level(2)


def test_version():
    print("Lite version: {}".format(version))


def test_config():
    config = LiteConfig()
    config.bare_model_cryption_name = "nothing"
    print(config)


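# build LiteNetworkIO in several ways (LiteIO objects, [name, is_host, ...] lists,
# plain names) and check the resulting input/output bookkeeping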
def test_network_io():
    input_io1 = LiteIO("data1", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
    input_io2 = LiteIO(
        "data2",
        is_host=True,
        io_type=LiteIOType.LITE_IO_SHAPE,
        layout=LiteLayout([2, 4, 4]),
    )
    io = LiteNetworkIO()
    io.add_input(input_io1)
    io.add_input(input_io2)
    io.add_input("data3", False)

    output_io1 = LiteIO("out1", is_host=False)
    output_io2 = LiteIO("out2", is_host=True, layout=LiteLayout([1, 1000]))
    io.add_output(output_io1)
    io.add_output(output_io2)

    assert len(io.inputs) == 3
    assert len(io.outputs) == 2
    assert io.inputs[0] == input_io1
    assert io.outputs[0] == output_io1

    c_io = io._create_network_io()
    assert c_io.input_size == 3
    assert c_io.output_size == 2

    ins = [["data1", True], ["data2", False, LiteIOType.LITE_IO_SHAPE]]
    outs = [["out1", True], ["out2", False, LiteIOType.LITE_IO_VALUE]]
    io2 = LiteNetworkIO(ins, outs)
    assert len(io2.inputs) == 2
    assert len(io2.outputs) == 2

    io3 = LiteNetworkIO([input_io1, input_io2], [output_io1, output_io2])
    assert len(io3.inputs) == 2
    assert len(io3.outputs) == 2

    test_io = LiteIO("test")
    assert test_io.name == "test"
    test_io.name = "test2"
    assert test_io.name == "test2"


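# the tests below load their resources from the directory pointed to by the
# LITE_TEST_RESOURCE environment variable: input_data.npy, output_data.npy
# and shufflenet.mge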
class TestShuffleNet(unittest.TestCase):
    source_dir = os.getenv("LITE_TEST_RESOURCE")
    input_data_path = os.path.join(source_dir, "input_data.npy")
    correct_data_path = os.path.join(source_dir, "output_data.npy")
    model_path = os.path.join(source_dir, "shufflenet.mge")
    correct_data = np.load(correct_data_path).flatten()
    input_data = np.load(input_data_path)

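    # element-wise comparison of the flattened output against the reference data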
    def check_correct(self, out_data, error=1e-4):
        out_data = out_data.flatten()
        assert np.isfinite(out_data.sum())
        assert self.correct_data.size == out_data.size
        for i in range(out_data.size):
            assert abs(out_data[i] - self.correct_data[i]) < error

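    # copy the input into the network, run `times` forward passes, check the output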
    def do_forward(self, network, times=3):
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        input_tensor.set_data_by_copy(self.input_data)
        for i in range(times):
            network.forward()
            network.wait()

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)


class TestNetwork(TestShuffleNet):
    def test_decryption(self):
        model_path = os.path.join(self.source_dir, "shufflenet_crypt_aes.mge")
        config = LiteConfig()
        config.bare_model_cryption_name = "AES_default".encode("utf-8")
        network = LiteNetwork(config)
        network.load(model_path)
        self.do_forward(network)

    def test_pack_model(self):
        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
        network = LiteNetwork()
        network.load(model_path)
        self.do_forward(network)

    def test_disable_model_config(self):
        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
        network = LiteNetwork()
        network.extra_configure(LiteExtraConfig(True))
        network.load(model_path)
        self.do_forward(network)

    def test_pack_cache_to_model(self):
        model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
        network = LiteNetwork()
        network.load(model_path)
        self.do_forward(network)

    def test_network_basic(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        assert input_tensor.layout.shapes[0] == 1
        assert input_tensor.layout.shapes[1] == 3
        assert input_tensor.layout.shapes[2] == 224
        assert input_tensor.layout.shapes[3] == 224
        assert input_tensor.layout.data_type == LiteDataType.LITE_FLOAT
        assert input_tensor.layout.ndim == 4

        self.do_forward(network)

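    # same as the basic test, but the input numpy array is shared with the tensor
    # (set_data_by_share) instead of copied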
    def test_network_shared_data(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        input_tensor.set_data_by_share(self.input_data)
        for i in range(3):
            network.forward()
            network.wait()

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

    def test_network_get_name(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_names = network.get_all_input_name()
        assert input_names[0] == "data"
        output_names = network.get_all_output_name()
        assert output_names[0] == network.get_output_name(0)
        self.do_forward(network)

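    # device_id, stream_id and threads_number are set before load(); the tests
    # below expect a RuntimeError when they are changed again after loading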
    def test_network_set_device_id(self):
        network = LiteNetwork()
        assert network.device_id == 0
        network.device_id = 1
        network.load(self.model_path)
        assert network.device_id == 1
        with self.assertRaises(RuntimeError):
            network.device_id = 1
        self.do_forward(network)

    def test_network_set_stream_id(self):
        network = LiteNetwork()
        assert network.stream_id == 0
        network.stream_id = 1
        network.load(self.model_path)
        assert network.stream_id == 1
        with self.assertRaises(RuntimeError):
            network.stream_id = 1
        self.do_forward(network)

    def test_network_set_thread_number(self):
        network = LiteNetwork()
        assert network.threads_number == 1
        network.threads_number = 2
        network.load(self.model_path)
        assert network.threads_number == 2
        with self.assertRaises(RuntimeError):
            network.threads_number = 2
        self.do_forward(network)

    def test_network_cpu_inplace(self):
        network = LiteNetwork()
        assert not network.is_cpu_inplace_mode()
        network.enable_cpu_inplace_mode()
        network.load(self.model_path)
        assert network.is_cpu_inplace_mode()
        with self.assertRaises(RuntimeError):
            network.enable_cpu_inplace_mode()
        self.do_forward(network)

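    # LiteOptions fields are plain integer flags; here weight preprocessing is
    # enabled and the first-run var sanity check is disabled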
    def test_network_option(self):
        option = LiteOptions()
        option.weight_preprocess = 1
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        network = LiteNetwork(config=config)
        network.load(self.model_path)
        self.do_forward(network)

    def test_network_reset_io(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        input_io = LiteIO("data")
        ios = LiteNetworkIO()
        ios.add_input(input_io)
        network = LiteNetwork(config=config, io=ios)
        network.load(self.model_path)
        input_tensor = network.get_io_tensor("data")
        assert input_tensor.device_type == LiteDeviceType.LITE_CPU
        self.do_forward(network)

    def test_network_by_share(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        assert input_tensor.device_type == LiteDeviceType.LITE_CPU
        layout = LiteLayout(self.input_data.shape, self.input_data.dtype)
        tensor_tmp = LiteTensor(layout=layout)
        tensor_tmp.set_data_by_share(self.input_data)
        input_tensor.share_memory_with(tensor_tmp)

        for i in range(3):
            network.forward()
            network.wait()

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

    def test_network_share_weights(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)

        src_network = LiteNetwork(config=config)
        src_network.load(self.model_path)

        new_network = LiteNetwork()
        new_network.enable_cpu_inplace_mode()
        new_network.share_weights_with(src_network)

        self.do_forward(src_network)
        self.do_forward(new_network)

    def test_network_share_runtime_memory(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)

        src_network = LiteNetwork(config=config)
        src_network.load(self.model_path)

        new_network = LiteNetwork()
        new_network.enable_cpu_inplace_mode()
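        # note: "memroy" below is kept as-is; it matches the method name used by
        # the megenginelite Python API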
        new_network.share_runtime_memroy(src_network)
        new_network.load(self.model_path)

        self.do_forward(src_network)
        self.do_forward(new_network)

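    # run one forward pass with an async callback and busy-wait until the
    # callback marks the inference as finished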
    def test_network_async(self):
        count = 0
        finished = False

        def async_callback():
            nonlocal finished
            finished = True
            return 0

        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)

        network = LiteNetwork(config=config)
        network.load(self.model_path)
        network.async_with_callback(async_callback)

        input_tensor = network.get_io_tensor(network.get_input_name(0))
        output_tensor = network.get_io_tensor(network.get_output_name(0))

        input_tensor.set_data_by_share(self.input_data)
        network.forward()

        while not finished:
            count += 1
        assert count > 0

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

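    # the start/finish callbacks receive a mapping from LiteIO to LiteTensor for
    # the network inputs/outputs; the callbacks check the data they are handed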
    def test_network_start_callback(self):
        network = LiteNetwork()
        network.load(self.model_path)
        start_checked = False

        def start_callback(ios):
            nonlocal start_checked
            start_checked = True
            assert len(ios) == 1
            for key in ios:
                io = key
                data = ios[key].to_numpy().flatten()
                input_data = self.input_data.flatten()
                assert data.size == input_data.size
                assert io.name == "data"
                for i in range(data.size):
                    assert abs(data[i] - input_data[i]) < 1e-5
            return 0

        network.set_start_callback(start_callback)
        self.do_forward(network, 1)
        assert start_checked

    def test_network_finish_callback(self):
        network = LiteNetwork()
        network.load(self.model_path)
        finish_checked = False

        def finish_callback(ios):
            nonlocal finish_checked
            finish_checked = True
            assert len(ios) == 1
            for key in ios:
                io = key
                data = ios[key].to_numpy().flatten()
                output_data = self.correct_data.flatten()
                assert data.size == output_data.size
                for i in range(data.size):
                    assert abs(data[i] - output_data[i]) < 1e-5
            return 0

        network.set_finish_callback(finish_callback)
        self.do_forward(network, 1)
        assert finish_checked

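    # the profiling and dump tests only check that the expected output files get
    # written, then remove the temporaries that the tests create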
    def test_enable_profile(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.enable_profile_performance("./profile.json")
        self.do_forward(network)
        fi = open("./profile.json", "r")
        fi.close()
        os.remove("./profile.json")

    def test_io_txt_dump(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.io_txt_dump("./io_txt.txt")
        self.do_forward(network)

    def test_io_bin_dump(self):
        import shutil

        folder = "./out"
        network = LiteNetwork()
        network.load(self.model_path)
        if not os.path.exists(folder):
            os.mkdir(folder)
        network.io_bin_dump(folder)
        self.do_forward(network)
        shutil.rmtree(folder)

    def test_algo_workspace_limit(self):
        network = LiteNetwork()
        network.load(self.model_path)
        print("modify the workspace limit.")
        network.set_network_algo_workspace_limit(10000)
        self.do_forward(network)

    def test_network_algo_policy(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
            | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
        )
        self.do_forward(network)

    def test_network_algo_policy_ignore_batch(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE,
            shared_batch_size=1,
            binary_equal_between_batch=True,
        )
        self.do_forward(network)

    def test_device_tensor_no_copy(self):
        # construct LiteOption
        net_config = LiteConfig()
        net_config.options.force_output_use_user_specified_memory = True

        network = LiteNetwork(config=net_config)
        network.load(self.model_path)

        input_tensor = network.get_io_tensor("data")
        # fill input_data with device data
        input_tensor.set_data_by_share(self.input_data)

        output_tensor = network.get_io_tensor(network.get_output_name(0))
        out_array = np.zeros(output_tensor.layout.shapes, output_tensor.layout.dtype)
        output_tensor.set_data_by_share(out_array)

        # inference
        for i in range(2):
            network.forward()
            network.wait()

        self.check_correct(out_array)

    def test_enable_global_layout_transform(self):
        network = LiteNetwork()
        network.enable_global_layout_transform()
        network.load(self.model_path)
        self.do_forward(network)

    def test_dump_layout_transform_model(self):
        network = LiteNetwork()
        network.enable_global_layout_transform()
        network.load(self.model_path)
        network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
        self.do_forward(network)
        fi = open("./model_afer_layoutTrans.mgb", "r")
        fi.close()
        os.remove("./model_afer_layoutTrans.mgb")

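    # combine a profiling algo policy with the global layout transform, dump both
    # the transformed model and the persistent algo cache, then reload the cache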
    def test_fast_run_and_global_layout_transform(self):
        config_ = LiteConfig()
        network = LiteNetwork(config_)
        fast_run_cache = "./algo_cache"
        global_layout_transform_model = "./model_afer_layoutTrans.mgb"

        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
        )
        network.enable_global_layout_transform()
        network.load(self.model_path)
        self.do_forward(network)

        network.dump_layout_transform_model(global_layout_transform_model)
        LiteGlobal.dump_persistent_cache(fast_run_cache)
        fi = open(fast_run_cache, "r")
        fi.close()
        fi = open(global_layout_transform_model, "r")
        fi.close()

        LiteGlobal.set_persistent_cache(path=fast_run_cache)
        self.do_forward(network)

        os.remove(fast_run_cache)
        os.remove(global_layout_transform_model)