You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_network.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import unittest

import numpy as np

# Wildcard import provides LiteNetwork, LiteIO, LiteConfig, set_log_level, etc.
from megenginelite import *

# Raise lite's log verbosity so test runs emit progress information.
set_log_level(2)
  14. def test_version():
  15. print("Lite verson: {}".format(version))
  16. def test_network_io():
  17. input_io1 = LiteIO("data1", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
  18. input_io2 = LiteIO(
  19. "data2",
  20. is_host=True,
  21. io_type=LiteIOType.LITE_IO_SHAPE,
  22. layout=LiteLayout([2, 4, 4]),
  23. )
  24. io = LiteNetworkIO()
  25. io.add_input(input_io1)
  26. io.add_input(input_io2)
  27. output_io1 = LiteIO("out1", is_host=False)
  28. output_io2 = LiteIO("out2", is_host=True, layout=LiteLayout([1, 1000]))
  29. io.add_output(output_io1)
  30. io.add_output(output_io2)
  31. assert len(io.inputs) == 2
  32. assert len(io.outputs) == 2
  33. assert io.inputs[0] == input_io1
  34. assert io.outputs[0] == output_io1
  35. c_io = io._create_network_io()
  36. assert c_io.input_size == 2
  37. assert c_io.output_size == 2
  38. class TestShuffleNet(unittest.TestCase):
  39. source_dir = os.getenv("LITE_TEST_RESOUCE")
  40. input_data_path = os.path.join(source_dir, "input_data.npy")
  41. correct_data_path = os.path.join(source_dir, "output_data.npy")
  42. model_path = os.path.join(source_dir, "shufflenet.mge")
  43. correct_data = np.load(correct_data_path).flatten()
  44. input_data = np.load(input_data_path)
  45. def check_correct(self, out_data, error=1e-4):
  46. out_data = out_data.flatten()
  47. assert np.isfinite(out_data.sum())
  48. assert self.correct_data.size == out_data.size
  49. for i in range(out_data.size):
  50. assert abs(out_data[i] - self.correct_data[i]) < error
  51. def do_forward(self, network, times=3):
  52. input_name = network.get_input_name(0)
  53. input_tensor = network.get_io_tensor(input_name)
  54. output_name = network.get_output_name(0)
  55. output_tensor = network.get_io_tensor(output_name)
  56. input_tensor.set_data_by_copy(self.input_data)
  57. for i in range(times):
  58. network.forward()
  59. network.wait()
  60. output_data = output_tensor.to_numpy()
  61. self.check_correct(output_data)
class TestNetwork(TestShuffleNet):
    """End-to-end LiteNetwork tests driven by the shufflenet fixture."""

    def test_decryption(self):
        # Load a model encrypted with the built-in "AES_default" scheme.
        model_path = os.path.join(self.source_dir, "shufflenet_crypt_aes.mge")
        config = LiteConfig()
        config.bare_model_cryption_name = "AES_default".encode("utf-8")
        network = LiteNetwork(config)
        network.load(model_path)
        self.do_forward(network)

    def test_pack_model(self):
        # A packed-model container (RC4) loads without an explicit config.
        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
        network = LiteNetwork()
        network.load(model_path)
        self.do_forward(network)

    def test_network_basic(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)
        # The fixture model expects a single 1x3x224x224 float input.
        assert input_tensor.layout.shapes[0] == 1
        assert input_tensor.layout.shapes[1] == 3
        assert input_tensor.layout.shapes[2] == 224
        assert input_tensor.layout.shapes[3] == 224
        assert input_tensor.layout.data_type == LiteDataType.LITE_FLOAT
        assert input_tensor.layout.ndim == 4
        self.do_forward(network)

    def test_network_shared_data(self):
        # Feed input by sharing memory with the numpy array (no copy).
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)
        input_tensor.set_data_by_share(self.input_data)
        for i in range(3):
            network.forward()
            network.wait()
        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

    def test_network_get_name(self):
        network = LiteNetwork()
        network.load(self.model_path)
        input_names = network.get_all_input_name()
        assert input_names[0] == "data"
        output_names = network.get_all_output_name()
        assert output_names[0] == network.get_output_name(0)
        self.do_forward(network)

    def test_network_set_device_id(self):
        network = LiteNetwork()
        assert network.device_id == 0
        network.device_id = 1
        network.load(self.model_path)
        assert network.device_id == 1
        # The device id is frozen once the model is loaded.
        with self.assertRaises(RuntimeError):
            network.device_id = 1
        self.do_forward(network)

    def test_network_set_stream_id(self):
        network = LiteNetwork()
        assert network.stream_id == 0
        network.stream_id = 1
        network.load(self.model_path)
        assert network.stream_id == 1
        # The stream id is frozen once the model is loaded.
        with self.assertRaises(RuntimeError):
            network.stream_id = 1
        self.do_forward(network)

    def test_network_set_thread_number(self):
        network = LiteNetwork()
        assert network.threads_number == 1
        network.threads_number = 2
        network.load(self.model_path)
        assert network.threads_number == 2
        # The thread count is frozen once the model is loaded.
        with self.assertRaises(RuntimeError):
            network.threads_number = 2
        self.do_forward(network)

    def test_network_cpu_inplace(self):
        network = LiteNetwork()
        assert network.is_cpu_inplace_mode() == False
        network.enable_cpu_inplace_mode()
        network.load(self.model_path)
        assert network.is_cpu_inplace_mode() == True
        # The mode cannot be toggled after load.
        with self.assertRaises(RuntimeError):
            network.enable_cpu_inplace_mode()
        self.do_forward(network)

    def test_network_option(self):
        # Forward still works with weight preprocessing on and the
        # first-run sanity check disabled.
        option = LiteOptions()
        option.weight_preprocess = 1
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        network = LiteNetwork(config=config)
        network.load(self.model_path)
        self.do_forward(network)

    def test_network_reset_io(self):
        # Override the default IO description for the "data" input.
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        input_io = LiteIO("data")
        ios = LiteNetworkIO()
        ios.add_input(input_io)
        network = LiteNetwork(config=config, io=ios)
        network.load(self.model_path)
        input_tensor = network.get_io_tensor("data")
        assert input_tensor.device_type == LiteDeviceType.LITE_CPU
        self.do_forward(network)

    def test_network_by_share(self):
        # Feed input via tensor-to-tensor memory sharing.
        network = LiteNetwork()
        network.load(self.model_path)
        input_name = network.get_input_name(0)
        input_tensor = network.get_io_tensor(input_name)
        output_name = network.get_output_name(0)
        output_tensor = network.get_io_tensor(output_name)
        assert input_tensor.device_type == LiteDeviceType.LITE_CPU
        layout = LiteLayout(self.input_data.shape, self.input_data.dtype)
        tensor_tmp = LiteTensor(layout=layout)
        tensor_tmp.set_data_by_share(self.input_data)
        input_tensor.share_memory_with(tensor_tmp)
        for i in range(3):
            network.forward()
            network.wait()
        output_data = output_tensor.to_numpy()
        self.check_correct(output_data)

    def test_network_share_weights(self):
        # A second network sharing weights with a loaded one must still
        # produce correct results; note it is not load()-ed itself.
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        src_network = LiteNetwork(config=config)
        src_network.load(self.model_path)
        new_network = LiteNetwork()
        new_network.enable_cpu_inplace_mode()
        new_network.share_weights_with(src_network)
        self.do_forward(src_network)
        self.do_forward(new_network)

    def test_network_share_runtime_memory(self):
        option = LiteOptions()
        option.var_sanity_check_first_run = 0
        config = LiteConfig(option=option)
        src_network = LiteNetwork(config=config)
        src_network.load(self.model_path)
        new_network = LiteNetwork()
        new_network.enable_cpu_inplace_mode()
        # NOTE: "memroy" (sic) is the actual spelling of the binding's
        # method name; do not "fix" it here.
        new_network.share_runtime_memroy(src_network)
        new_network.load(self.model_path)
        self.do_forward(src_network)
        self.do_forward(new_network)

    # The async/callback tests below are intentionally disabled upstream;
    # kept verbatim for reference.
    # def test_network_async(self):
    #     count = 0
    #     finished = False
    #
    #     def async_callback():
    #         nonlocal finished
    #         finished = True
    #         return 0
    #
    #     option = LiteOptions()
    #     option.var_sanity_check_first_run = 0
    #     config = LiteConfig(option=option)
    #
    #     network = LiteNetwork(config=config)
    #     network.load(self.model_path)
    #
    #     network.async_with_callback(async_callback)
    #
    #     input_tensor = network.get_io_tensor(network.get_input_name(0))
    #     output_tensor = network.get_io_tensor(network.get_output_name(0))
    #
    #     input_tensor.set_data_by_share(self.input_data)
    #     network.forward()
    #
    #     while not finished:
    #         count += 1
    #
    #     assert count > 0
    #     output_data = output_tensor.to_numpy()
    #     self.check_correct(output_data)
    #
    # def test_network_start_callback(self):
    #     network = LiteNetwork()
    #     network.load(self.model_path)
    #     start_checked = False
    #
    #     @start_finish_callback
    #     def start_callback(ios):
    #         nonlocal start_checked
    #         start_checked = True
    #         assert len(ios) == 1
    #         for key in ios:
    #             io = key
    #             data = ios[key].to_numpy().flatten()
    #             input_data = self.input_data.flatten()
    #             assert data.size == input_data.size
    #             assert io.name.decode("utf-8") == "data"
    #             for i in range(data.size):
    #                 assert data[i] == input_data[i]
    #         return 0
    #
    #     network.set_start_callback(start_callback)
    #     self.do_forward(network, 1)
    #     assert start_checked == True
    #
    # def test_network_finish_callback(self):
    #     network = LiteNetwork()
    #     network.load(self.model_path)
    #     finish_checked = False
    #
    #     @start_finish_callback
    #     def finish_callback(ios):
    #         nonlocal finish_checked
    #         finish_checked = True
    #         assert len(ios) == 1
    #         for key in ios:
    #             io = key
    #             data = ios[key].to_numpy().flatten()
    #             output_data = self.correct_data.flatten()
    #             assert data.size == output_data.size
    #             for i in range(data.size):
    #                 assert data[i] == output_data[i]
    #         return 0
    #
    #     network.set_finish_callback(finish_callback)
    #     self.do_forward(network, 1)
    #     assert finish_checked == True

    def test_enable_profile(self):
        # Profiling must produce a readable JSON file; clean it up after.
        network = LiteNetwork()
        network.load(self.model_path)
        network.enable_profile_performance("./profile.json")
        self.do_forward(network)
        fi = open("./profile.json", "r")
        fi.close()
        os.remove("./profile.json")

    def test_io_txt_dump(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.io_txt_dump("./io_txt.txt")
        self.do_forward(network)

    def test_io_bin_dump(self):
        import shutil

        folder = "./out"
        network = LiteNetwork()
        network.load(self.model_path)
        if not os.path.exists(folder):
            os.mkdir(folder)
        network.io_bin_dump(folder)
        self.do_forward(network)
        shutil.rmtree(folder)

    def test_algo_workspace_limit(self):
        network = LiteNetwork()
        network.load(self.model_path)
        print("modify the workspace limit.")
        network.set_network_algo_workspace_limit(10000)
        self.do_forward(network)

    def test_network_algo_policy(self):
        # Combine profile + reproducible algorithm-selection strategies.
        network = LiteNetwork()
        network.load(self.model_path)
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
            | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
        )
        self.do_forward(network)

    def test_network_algo_policy_ignore_batch(self):
        network = LiteNetwork()
        network.load(self.model_path)
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE,
            shared_batch_size=1,
            binary_equal_between_batch=True,
        )
        self.do_forward(network)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台