You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_network_device.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import os
  11. import unittest
  12. import numpy as np
  13. from megenginelite import *
# Raise the megenginelite log verbosity so the tests emit diagnostics.
set_log_level(2)
  15. def require_cuda(ngpu=1):
  16. """a decorator that disables a testcase if cuda is not enabled"""
  17. def dector(func):
  18. @functools.wraps(func)
  19. def wrapped(*args, **kwargs):
  20. if LiteGlobal.get_device_count(LiteDeviceType.LITE_CUDA) >= ngpu:
  21. return func(*args, **kwargs)
  22. return wrapped
  23. return dector
  24. class TestShuffleNetCuda(unittest.TestCase):
  25. source_dir = os.getenv("LITE_TEST_RESOURCE")
  26. input_data_path = os.path.join(source_dir, "input_data.npy")
  27. correct_data_path = os.path.join(source_dir, "output_data.npy")
  28. model_path = os.path.join(source_dir, "shufflenet.mge")
  29. correct_data = np.load(correct_data_path).flatten()
  30. input_data = np.load(input_data_path)
  31. def check_correct(self, out_data, error=1e-4):
  32. out_data = out_data.flatten()
  33. assert np.isfinite(out_data.sum())
  34. assert self.correct_data.size == out_data.size
  35. for i in range(out_data.size):
  36. assert abs(out_data[i] - self.correct_data[i]) < error
  37. def do_forward(self, network, times=3):
  38. input_name = network.get_input_name(0)
  39. input_tensor = network.get_io_tensor(input_name)
  40. output_name = network.get_output_name(0)
  41. output_tensor = network.get_io_tensor(output_name)
  42. input_tensor.set_data_by_copy(self.input_data)
  43. for i in range(times):
  44. network.forward()
  45. network.wait()
  46. output_data = output_tensor.to_numpy()
  47. self.check_correct(output_data)
  48. class TestNetwork(TestShuffleNetCuda):
  49. @require_cuda()
  50. def test_network_basic(self):
  51. config = LiteConfig()
  52. config.device_type = LiteDeviceType.LITE_CUDA
  53. network = LiteNetwork(config)
  54. network.load(self.model_path)
  55. input_name = network.get_input_name(0)
  56. input_tensor = network.get_io_tensor(input_name)
  57. output_name = network.get_output_name(0)
  58. output_tensor = network.get_io_tensor(output_name)
  59. assert input_tensor.layout.shapes[0] == 1
  60. assert input_tensor.layout.shapes[1] == 3
  61. assert input_tensor.layout.shapes[2] == 224
  62. assert input_tensor.layout.shapes[3] == 224
  63. assert input_tensor.layout.data_type == LiteDataType.LITE_FLOAT
  64. assert input_tensor.layout.ndim == 4
  65. self.do_forward(network)
  66. @require_cuda()
  67. def test_network_shared_data(self):
  68. config = LiteConfig()
  69. config.device_type = LiteDeviceType.LITE_CUDA
  70. network = LiteNetwork(config)
  71. network.load(self.model_path)
  72. input_name = network.get_input_name(0)
  73. input_tensor = network.get_io_tensor(input_name)
  74. output_name = network.get_output_name(0)
  75. output_tensor = network.get_io_tensor(output_name)
  76. input_tensor.set_data_by_share(self.input_data)
  77. for i in range(3):
  78. network.forward()
  79. network.wait()
  80. output_data = output_tensor.to_numpy()
  81. self.check_correct(output_data)
  82. @require_cuda(2)
  83. def test_network_set_device_id(self):
  84. config = LiteConfig()
  85. config.device_type = LiteDeviceType.LITE_CUDA
  86. network = LiteNetwork(config)
  87. assert network.device_id == 0
  88. network.device_id = 1
  89. network.load(self.model_path)
  90. assert network.device_id == 1
  91. with self.assertRaises(RuntimeError):
  92. network.device_id = 1
  93. self.do_forward(network)
  94. @require_cuda()
  95. def test_network_option(self):
  96. option = LiteOptions()
  97. option.weight_preprocess = 1
  98. option.var_sanity_check_first_run = 0
  99. config = LiteConfig(option=option)
  100. config.device_type = LiteDeviceType.LITE_CUDA
  101. network = LiteNetwork(config=config)
  102. network.load(self.model_path)
  103. self.do_forward(network)
  104. @require_cuda()
  105. def test_network_reset_io(self):
  106. option = LiteOptions()
  107. option.var_sanity_check_first_run = 0
  108. config = LiteConfig(option=option)
  109. config.device_type = LiteDeviceType.LITE_CUDA
  110. input_io = LiteIO("data")
  111. ios = LiteNetworkIO()
  112. ios.add_input(input_io)
  113. network = LiteNetwork(config=config, io=ios)
  114. network.load(self.model_path)
  115. input_tensor = network.get_io_tensor("data")
  116. assert input_tensor.device_type == LiteDeviceType.LITE_CPU
  117. self.do_forward(network)
  118. @require_cuda()
  119. def test_network_share_weights(self):
  120. option = LiteOptions()
  121. option.var_sanity_check_first_run = 0
  122. config = LiteConfig(option=option)
  123. config.device_type = LiteDeviceType.LITE_CUDA
  124. src_network = LiteNetwork(config=config)
  125. src_network.load(self.model_path)
  126. new_network = LiteNetwork()
  127. new_network.enable_cpu_inplace_mode()
  128. new_network.share_weights_with(src_network)
  129. self.do_forward(src_network)
  130. self.do_forward(new_network)
  131. @require_cuda()
  132. def test_network_share_runtime_memory(self):
  133. option = LiteOptions()
  134. option.var_sanity_check_first_run = 0
  135. config = LiteConfig(option=option)
  136. config.device_type = LiteDeviceType.LITE_CUDA
  137. src_network = LiteNetwork(config=config)
  138. src_network.load(self.model_path)
  139. new_network = LiteNetwork()
  140. new_network.enable_cpu_inplace_mode()
  141. new_network.share_runtime_memroy(src_network)
  142. new_network.load(self.model_path)
  143. self.do_forward(src_network)
  144. self.do_forward(new_network)
  145. @require_cuda
  146. def test_network_start_callback(self):
  147. config = LiteConfig()
  148. config.device = LiteDeviceType.LITE_CUDA
  149. network = LiteNetwork(config)
  150. network.load(self.model_path)
  151. start_checked = False
  152. def start_callback(ios):
  153. nonlocal start_checked
  154. start_checked = True
  155. assert len(ios) == 1
  156. for key in ios:
  157. io = key
  158. data = ios[key].to_numpy().flatten()
  159. input_data = self.input_data.flatten()
  160. assert data.size == input_data.size
  161. assert io.name.decode("utf-8") == "data"
  162. for i in range(data.size):
  163. assert data[i] == input_data[i]
  164. return 0
  165. network.set_start_callback(start_callback)
  166. self.do_forward(network, 1)
  167. assert start_checked == True
  168. @require_cuda
  169. def test_network_finish_callback(self):
  170. config = LiteConfig()
  171. config.device = LiteDeviceType.LITE_CUDA
  172. network = LiteNetwork(config)
  173. network.load(self.model_path)
  174. finish_checked = False
  175. def finish_callback(ios):
  176. nonlocal finish_checked
  177. finish_checked = True
  178. assert len(ios) == 1
  179. for key in ios:
  180. io = key
  181. data = ios[key].to_numpy().flatten()
  182. output_data = self.correct_data.flatten()
  183. assert data.size == output_data.size
  184. for i in range(data.size):
  185. assert data[i] == output_data[i]
  186. return 0
  187. network.set_finish_callback(finish_callback)
  188. self.do_forward(network, 1)
  189. assert finish_checked == True
  190. @require_cuda()
  191. def test_enable_profile(self):
  192. config = LiteConfig()
  193. config.device_type = LiteDeviceType.LITE_CUDA
  194. network = LiteNetwork(config)
  195. network.load(self.model_path)
  196. network.enable_profile_performance("./profile.json")
  197. self.do_forward(network)
  198. fi = open("./profile.json", "r")
  199. fi.close()
  200. os.remove("./profile.json")
  201. @require_cuda()
  202. def test_algo_workspace_limit(self):
  203. config = LiteConfig()
  204. config.device_type = LiteDeviceType.LITE_CUDA
  205. network = LiteNetwork(config)
  206. network.load(self.model_path)
  207. print("modify the workspace limit.")
  208. network.set_network_algo_workspace_limit(10000)
  209. self.do_forward(network)
  210. @require_cuda()
  211. def test_network_algo_policy(self):
  212. config = LiteConfig()
  213. config.device_type = LiteDeviceType.LITE_CUDA
  214. network = LiteNetwork(config)
  215. network.load(self.model_path)
  216. network.set_network_algo_policy(
  217. LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
  218. | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
  219. )
  220. self.do_forward(network)
  221. @require_cuda()
  222. def test_enable_global_layout_transform(self):
  223. network = LiteNetwork()
  224. network.enable_global_layout_transform()
  225. network.load(self.model_path)
  226. self.do_forward(network)
  227. @require_cuda()
  228. def test_dump_layout_transform_model(self):
  229. network = LiteNetwork()
  230. network.enable_global_layout_transform()
  231. network.load(self.model_path)
  232. network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
  233. self.do_forward(network)
  234. fi = open("./model_afer_layoutTrans.mgb", "r")
  235. fi.close()
  236. os.remove("./model_afer_layoutTrans.mgb")