
test_network.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import unittest

import numpy as np

from megenginelite import *

set_log_level(2)

def test_version():
    print("Lite version: {}".format(version))

def test_config():
    config = LiteConfig()
    config.bare_model_cryption_name = "nothing"
    print(config)

def test_network_io():
    input_io1 = LiteIO("data1", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
    input_io2 = LiteIO(
        "data2",
        is_host=True,
        io_type=LiteIOType.LITE_IO_SHAPE,
        layout=LiteLayout([2, 4, 4]),
    )
    io = LiteNetworkIO()
    io.add_input(input_io1)
    io.add_input(input_io2)
    io.add_input("data3", False)
    output_io1 = LiteIO("out1", is_host=False)
    output_io2 = LiteIO("out2", is_host=True, layout=LiteLayout([1, 1000]))
    io.add_output(output_io1)
    io.add_output(output_io2)
    assert len(io.inputs) == 3
    assert len(io.outputs) == 2
    assert io.inputs[0] == input_io1
    assert io.outputs[0] == output_io1
    c_io = io._create_network_io()
    assert c_io.input_size == 3
    assert c_io.output_size == 2

    ins = [["data1", True], ["data2", False, LiteIOType.LITE_IO_SHAPE]]
    outs = [["out1", True], ["out2", False, LiteIOType.LITE_IO_VALUE]]
    io2 = LiteNetworkIO(ins, outs)
    assert len(io2.inputs) == 2
    assert len(io2.outputs) == 2

    io3 = LiteNetworkIO([input_io1, input_io2], [output_io1, output_io2])
    assert len(io3.inputs) == 2
    assert len(io3.outputs) == 2

    test_io = LiteIO("test")
    assert test_io.name == "test"
    test_io.name = "test2"
    assert test_io.name == "test2"

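# Base fixture: loads the ShuffleNet test model, its input and the reference
# output from the directory given by the LITE_TEST_RESOURCE environment variable.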
class TestShuffleNet(unittest.TestCase):
    source_dir = os.getenv("LITE_TEST_RESOURCE")
    input_data_path = os.path.join(source_dir, "input_data.npy")
    correct_data_path = os.path.join(source_dir, "output_data.npy")
    model_path = os.path.join(source_dir, "shufflenet.mge")
    correct_data = np.load(correct_data_path).flatten()
    input_data = np.load(input_data_path)

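    # Compare the flattened output element-wise against the reference data.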
    def check_correct(self, out_data, error=1e-4):
        out_data = out_data.flatten()
        assert np.isfinite(out_data.sum())
        assert self.correct_data.size == out_data.size
        for i in range(out_data.size):
            assert abs(out_data[i] - self.correct_data[i]) < error

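    # Copy the test input in, run `times` forward passes, then check the first
    # output tensor against the reference.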
    def do_forward(self, network, times=3):
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        input_tensor.set_data_by_copy(self.input_data)
        for i in range(times):
            network.forward()
            network.wait()

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

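# Network-level tests covering configuration, I/O handling and execution options.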
class TestNetwork(TestShuffleNet):
    def test_decryption(self):
        model_path = os.path.join(self.source_dir, "shufflenet_crypt_aes.mge")
        config = LiteConfig()
        config.bare_model_cryption_name = "AES_default".encode("utf-8")
        network = LiteNetwork(config)
        network.load(model_path)
        self.do_forward(network)

    def test_pack_model(self):
        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
        network = LiteNetwork()
        network.load(model_path)
        self.do_forward(network)

    def test_network_basic(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        assert input_tensor.layout.shapes[0] == 1
        assert input_tensor.layout.shapes[1] == 3
        assert input_tensor.layout.shapes[2] == 224
        assert input_tensor.layout.shapes[3] == 224
        assert input_tensor.layout.data_type == LiteDataType.LITE_FLOAT
        assert input_tensor.layout.ndim == 4

        self.do_forward(network)

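    # set_data_by_share feeds the input without copying; results must still
    # match the reference.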
    def test_network_shared_data(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)
        input_tensor.set_data_by_share(self.input_data)

        for i in range(3):
            network.forward()
            network.wait()

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

    def test_network_get_name(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_names = network.get_all_input_name()
        assert input_names[0] == "data"
        output_names = network.get_all_output_name()
        assert output_names[0] == network.get_output_name(0)
        self.do_forward(network)

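    # device_id, stream_id and threads_number must be set before load();
    # assigning them after the model is loaded raises RuntimeError.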
    def test_network_set_device_id(self):
        network = LiteNetwork()
        assert network.device_id == 0
        network.device_id = 1
        network.load(self.model_path)
        assert network.device_id == 1

        with self.assertRaises(RuntimeError):
            network.device_id = 1

        self.do_forward(network)

    def test_network_set_stream_id(self):
        network = LiteNetwork()
        assert network.stream_id == 0
        network.stream_id = 1
        network.load(self.model_path)
        assert network.stream_id == 1

        with self.assertRaises(RuntimeError):
            network.stream_id = 1

        self.do_forward(network)

    def test_network_set_thread_number(self):
        network = LiteNetwork()
        assert network.threads_number == 1
        network.threads_number = 2
        network.load(self.model_path)
        assert network.threads_number == 2

        with self.assertRaises(RuntimeError):
            network.threads_number = 2

        self.do_forward(network)

    def test_network_cpu_inplace(self):
        network = LiteNetwork()
        assert network.is_cpu_inplace_mode() == False
        network.enable_cpu_inplace_mode()
        network.load(self.model_path)
        assert network.is_cpu_inplace_mode() == True

        with self.assertRaises(RuntimeError):
            network.enable_cpu_inplace_mode()

        self.do_forward(network)

    def test_network_option(self):
        option = LiteOptions()
        option.weight_preprocess = 1
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        network = LiteNetwork(config=config)
        network.load(self.model_path)
        self.do_forward(network)

    def test_network_reset_io(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        input_io = LiteIO("data")
        ios = LiteNetworkIO()
        ios.add_input(input_io)
        network = LiteNetwork(config=config, io=ios)
        network.load(self.model_path)

        input_tensor = network.get_io_tensor("data")
        assert input_tensor.device_type == LiteDeviceType.LITE_CPU
        self.do_forward(network)

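    # Route the input through a temporary LiteTensor whose memory is shared
    # with the network's input tensor.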
    def test_network_by_share(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)

        assert input_tensor.device_type == LiteDeviceType.LITE_CPU
        layout = LiteLayout(self.input_data.shape, self.input_data.dtype)
        tensor_tmp = LiteTensor(layout=layout)
        tensor_tmp.set_data_by_share(self.input_data)
        input_tensor.share_memory_with(tensor_tmp)

        for i in range(3):
            network.forward()
            network.wait()

        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

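    # A second network reuses the weights of the first via share_weights_with();
    # both networks must still produce correct outputs.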
    def test_network_share_weights(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        src_network = LiteNetwork(config=config)
        src_network.load(self.model_path)

        new_network = LiteNetwork()
        new_network.enable_cpu_inplace_mode()
        new_network.share_weights_with(src_network)

        self.do_forward(src_network)
        self.do_forward(new_network)

    def test_network_share_runtime_memory(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        src_network = LiteNetwork(config=config)
        src_network.load(self.model_path)

        new_network = LiteNetwork()
        new_network.enable_cpu_inplace_mode()
        new_network.share_runtime_memroy(src_network)
        new_network.load(self.model_path)

        self.do_forward(src_network)
        self.do_forward(new_network)

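    # Asynchronous execution: forward() returns immediately and the callback
    # sets `finished` once inference completes; the loop busy-waits for it.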
    def test_network_async(self):
        count = 0
        finished = False

        def async_callback():
            nonlocal finished
            finished = True
            return 0

        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        network = LiteNetwork(config=config)
        network.load(self.model_path)
        network.async_with_callback(async_callback)

        input_tensor = network.get_io_tensor(network.get_input_name(0))
        output_tensor = network.get_io_tensor(network.get_output_name(0))

        input_tensor.set_data_by_share(self.input_data)
        network.forward()

        while not finished:
            count += 1

        assert count > 0
        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

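    # The start callback is invoked before execution with a mapping from input
    # LiteIO descriptors to their tensors, so the fed data can be inspected.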
    def test_network_start_callback(self):
        network = LiteNetwork()
        network.load(self.model_path)
        start_checked = False

        def start_callback(ios):
            nonlocal start_checked
            start_checked = True
            assert len(ios) == 1
            for key in ios:
                io = key
                data = ios[key].to_numpy().flatten()
                input_data = self.input_data.flatten()
                assert data.size == input_data.size
                assert io.name == "data"
                for i in range(data.size):
                    assert abs(data[i] - input_data[i]) < 1e-5
            return 0

        network.set_start_callback(start_callback)
        self.do_forward(network, 1)
        assert start_checked == True

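    # The finish callback is invoked after execution with the output tensors,
    # which are checked against the reference data here.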
    def test_network_finish_callback(self):
        network = LiteNetwork()
        network.load(self.model_path)
        finish_checked = False

        def finish_callback(ios):
            nonlocal finish_checked
            finish_checked = True
            assert len(ios) == 1
            for key in ios:
                io = key
                data = ios[key].to_numpy().flatten()
                output_data = self.correct_data.flatten()
                assert data.size == output_data.size
                for i in range(data.size):
                    assert abs(data[i] - output_data[i]) < 1e-5
            return 0

        network.set_finish_callback(finish_callback)
        self.do_forward(network, 1)
        assert finish_checked == True

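    # enable_profile_performance writes a JSON trace during forward; the test
    # only checks that the file exists and then removes it.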
    def test_enable_profile(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.enable_profile_performance("./profile.json")
        self.do_forward(network)

        fi = open("./profile.json", "r")
        fi.close()
        os.remove("./profile.json")

    def test_io_txt_dump(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.io_txt_dump("./io_txt.txt")
        self.do_forward(network)

    def test_io_bin_dump(self):
        import shutil

        folder = "./out"
        network = LiteNetwork()
        network.load(self.model_path)
        if not os.path.exists(folder):
            os.mkdir(folder)
        network.io_bin_dump(folder)
        self.do_forward(network)
        shutil.rmtree(folder)

    def test_algo_workspace_limit(self):
        network = LiteNetwork()
        network.load(self.model_path)
        print("modify the workspace limit.")
        network.set_network_algo_workspace_limit(10000)
        self.do_forward(network)

    def test_network_algo_policy(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
            | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
        )
        self.do_forward(network)

    def test_network_algo_policy_ignore_batch(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE,
            shared_batch_size=1,
            binary_equal_between_batch=True,
        )
        self.do_forward(network)

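    # force_output_use_user_specified_memory lets the network write its output
    # directly into the user-provided out_array, avoiding an extra copy.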
    def test_device_tensor_no_copy(self):
        # construct LiteOption
        net_config = LiteConfig()
        net_config.options.force_output_use_user_specified_memory = True

        network = LiteNetwork(config=net_config)
        network.load(self.model_path)

        input_tensor = network.get_io_tensor("data")
        # fill input_data with device data
        input_tensor.set_data_by_share(self.input_data)

        output_tensor = network.get_io_tensor(network.get_output_name(0))
        out_array = np.zeros(output_tensor.layout.shapes, output_tensor.layout.dtype)
        output_tensor.set_data_by_share(out_array)

        # inference
        for i in range(2):
            network.forward()
            network.wait()

        self.check_correct(out_array)

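    # With global layout transform enabled the loaded model is optimized for the
    # current platform's preferred tensor layouts; outputs must remain correct.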
    def test_enable_global_layout_transform(self):
        network = LiteNetwork()
        network.enable_global_layout_transform()
        network.load(self.model_path)
        self.do_forward(network)

    def test_dump_layout_transform_model(self):
        network = LiteNetwork()
        network.enable_global_layout_transform()
        network.load(self.model_path)
        network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
        self.do_forward(network)

        fi = open("./model_afer_layoutTrans.mgb", "r")
        fi.close()
        os.remove("./model_afer_layoutTrans.mgb")