You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

network.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. # -*- coding: utf-8 -*-
  2. # This file is part of MegEngine, a deep learning framework developed by
  3. # Megvii.
  4. #
  5. # Copyright (c) Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  6. from ctypes import *
  7. import numpy as np
  8. from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
  9. from .struct import *
  10. from .tensor import *
  11. class LiteOptions(Structure):
  12. """
  13. the inference options will be used to config a network
  14. """
  15. _fields_ = [
  16. ("weight_preprocess", c_int),
  17. ("fuse_preprocess", c_int),
  18. ("fake_next_exec", c_int),
  19. ("var_sanity_check_first_run", c_int),
  20. ("const_shape", c_int),
  21. ("force_dynamic_alloc", c_int),
  22. ("force_output_dynamic_alloc", c_int),
  23. ("no_profiling_on_shape_change", c_int),
  24. ("jit_level", c_int),
  25. ("comp_node_seq_record_level", c_int),
  26. ("graph_opt_level", c_int),
  27. ("async_exec_level", c_int),
  28. # layout transform options
  29. ("enable_nchw44", c_int),
  30. ("enable_nchw44_dot", c_int),
  31. ("enable_nchw88", c_int),
  32. ("enable_nhwcd4", c_int),
  33. ("enable_nchw4", c_int),
  34. ("enable_nchw32", c_int),
  35. ("enable_nchw64", c_int),
  36. ]
  37. def __init__(self):
  38. self.weight_preprocess = False
  39. self.fuse_preprocess = False
  40. self.fake_next_exec = False
  41. self.var_sanity_check_first_run = True
  42. self.const_shape = False
  43. self.force_dynamic_alloc = False
  44. self.force_output_dynamic_alloc = False
  45. self.no_profiling_on_shape_change = False
  46. self.jit_level = 0
  47. self.comp_node_seq_record_level = 0
  48. self.graph_opt_level = 2
  49. self.async_exec_level = 1
  50. def __repr__(self):
  51. data = {
  52. "weight_preprocess": bool(self.weight_preprocess),
  53. "fuse_preprocess": bool(self.fuse_preprocess),
  54. "fake_next_exec": bool(self.fake_next_exec),
  55. "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
  56. "const_shape": bool(self.const_shape),
  57. "force_dynamic_alloc": bool(self.force_dynamic_alloc),
  58. "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
  59. "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
  60. "jit_level": self.jit_level,
  61. "comp_node_seq_record_level": self.comp_node_seq_record_level,
  62. "graph_opt_level": self.graph_opt_level,
  63. "async_exec_level": self.async_exec_level,
  64. }
  65. return data.__repr__()
  66. class LiteConfig(Structure):
  67. """
  68. Configuration when load and compile the graph
  69. bare_model_cryption_name: is the bare model cryption method name, bare
  70. model is not pack model info inside
  71. use_loader_dynamic_param: when model forward with device loader of npu,
  72. use_loader_dynamic_param used to flag whether the loader use device input or
  73. output, if use device input or output it will set Non-zero , else set zero
  74. has_compression: flag whether the model is compressed, the compress
  75. method will used to read the model
  76. """
  77. _fields_ = [
  78. ("has_compression", c_int),
  79. ("device_id", c_int),
  80. ("device_type", c_int),
  81. ("backend", c_int),
  82. ("bare_model_cryption_name", c_char_p),
  83. ("options", LiteOptions),
  84. ]
  85. def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
  86. self.device_type = device_type
  87. if option:
  88. self.options = option
  89. else:
  90. self.options = LiteOptions()
  91. self.bare_model_cryption_name = c_char_p(b"")
  92. self.use_loader_dynamic_param = 0
  93. self.has_compression = 0
  94. self.backend = LiteBackend.LITE_DEFAULT
  95. def __repr__(self):
  96. data = {
  97. "has_compression": bool(self.has_compression),
  98. "device_id": LiteDeviceType(self.device_id),
  99. "device_type": LiteDeviceType(self.device_type),
  100. "backend": LiteBackend(self.backend),
  101. "bare_model_cryption_name": self.bare_model_cryption_name.decode("utf-8"),
  102. "options": self.options,
  103. }
  104. return data.__repr__()
  105. class LiteIO(Structure):
  106. """
  107. config the network input and output item
  108. name: the tensor name in the graph corresponding to the IO
  109. is_host: Used to mark where the input tensor comes from and the output where copy
  110. to, if is_host is true, the input is from host and output copy to host,
  111. otherwise device. Sometimes The input is from device and output no need
  112. copy to host, default is true.
  113. io_type: The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
  114. output tensor value is invaid, only shape will be set, default is VALUE
  115. config_layout: The layout of the config from user, if other layout is set before
  116. forward or get after forward, this layout will by pass. if no other
  117. layout is set before forward, this layout will work. if this layout is
  118. no set, the model will forward with its origin layout. if in output, it
  119. will used to check.
  120. """
  121. _fields_ = [
  122. ("name", c_char_p),
  123. ("is_host", c_int),
  124. ("io_type", c_int),
  125. ("config_layout", LiteLayout),
  126. ]
  127. def __init__(
  128. self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  129. ):
  130. if type(name) == str:
  131. self.name = c_char_p(name.encode("utf-8"))
  132. else:
  133. self.name = c_char_p(name)
  134. if layout:
  135. self.config_layout = layout
  136. else:
  137. self.config_layout = LiteLayout()
  138. self.is_host = is_host
  139. self.io_type = io_type
  140. def __repr__(self):
  141. data = {
  142. "name": self.name,
  143. "is_host": bool(self.is_host),
  144. "io_type": LiteIOType(self.io_type),
  145. "config_layout": self.config_layout,
  146. }
  147. return data.__repr__()
  148. def __hash__(self):
  149. return hash(self.name)
  150. class _LiteNetworkIO(Structure):
  151. """
  152. the input and output information when load the network
  153. """
  154. _fields_ = [
  155. ("inputs", POINTER(LiteIO)),
  156. ("outputs", POINTER(LiteIO)),
  157. ("input_size", c_size_t),
  158. ("output_size", c_size_t),
  159. ]
  160. def __init__(self):
  161. self.inputs = POINTER(LiteIO)()
  162. self.outputs = POINTER(LiteIO)()
  163. self.input_size = 0
  164. self.output_size = 0
  165. class LiteNetworkIO(object):
  166. """
  167. the input and output information for user to construct _LiteNetWorkIO
  168. """
  169. def __init__(self):
  170. self.inputs = []
  171. self.outputs = []
  172. def add_input(self, input_io):
  173. assert isinstance(input_io, LiteIO)
  174. self.inputs.append(input_io)
  175. def add_output(self, output_io):
  176. assert isinstance(output_io, LiteIO)
  177. self.outputs.append(output_io)
  178. def _create_network_io(self):
  179. network_io = _LiteNetworkIO()
  180. length = 1 if len(self.inputs) == 0 else len(self.inputs)
  181. self.c_inputs = (LiteIO * length)(*self.inputs)
  182. length = 1 if len(self.outputs) == 0 else len(self.outputs)
  183. self.c_outputs = (LiteIO * length)(*self.outputs)
  184. network_io.inputs = pointer(self.c_inputs[0])
  185. network_io.outputs = pointer(self.c_outputs[0])
  186. network_io.input_size = len(self.inputs)
  187. network_io.output_size = len(self.outputs)
  188. return network_io
  189. def __repr__(self):
  190. data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
  191. return data.__repr__()
# C callback signature used by LITE_set_async_callback: no arguments,
# returns an int status code.
LiteAsyncCallback = CFUNCTYPE(c_int)
  193. def start_finish_callback(func):
  194. @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
  195. def wrapper(c_ios, c_tensors, size):
  196. ios = {}
  197. for i in range(size):
  198. tensor = LiteTensor()
  199. tensor._tensor = c_tensors[i]
  200. tensor.update()
  201. io = c_ios[i]
  202. ios[io] = tensor
  203. return func(ios)
  204. return wrapper
class _NetworkAPI(_LiteCObjBase):
    """
    get the network api from the lib

    Declaration table consumed by _LiteCObjBase: each entry pairs an
    exported C symbol name with its ctypes argtypes list. The symbol
    strings must match the shared library exactly — presumably _LiteCObjBase
    resolves them from _lib (defined in .base, not visible here).
    """

    _api_ = [
        # network construction / model loading
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        # execution
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        # IO tensor / name queries
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        # device / thread configuration
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        # algorithm policy / workspace
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        # NOTE(review): "memroy" spelling matches the exported C symbol —
        # do not "fix" it here or the lookup will fail.
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        # profiling / debugging / callbacks
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork]),
        ("LITE_set_finish_callback", [_Cnetwork]),
    ]
  243. class LiteNetwork(object):
  244. """
  245. the network to load a model and forward
  246. """
  247. _api = _NetworkAPI()._lib
  248. def __init__(self, config=None, io=None):
  249. """
  250. create a network with config and networkio
  251. """
  252. self._network = _Cnetwork()
  253. if config:
  254. self.config = config
  255. else:
  256. self.config = LiteConfig()
  257. if io:
  258. self.network_io = io
  259. else:
  260. self.network_io = LiteNetworkIO()
  261. c_network_io = self.network_io._create_network_io()
  262. self._api.LITE_make_network(byref(self._network), self.config, c_network_io)
  263. def __repr__(self):
  264. data = {"config": self.config, "IOs": self.network_io}
  265. return data.__repr__()
  266. def __del__(self):
  267. self._api.LITE_destroy_network(self._network)
  268. def load(self, path):
  269. c_path = c_char_p(path.encode("utf-8"))
  270. self._api.LITE_load_model_from_path(self._network, c_path)
  271. def forward(self):
  272. self._api.LITE_forward(self._network)
  273. def wait(self):
  274. self._api.LITE_wait(self._network)
  275. def is_cpu_inplace_mode(self):
  276. """
  277. whether the network run in cpu inpalce mode
  278. """
  279. inplace = c_int()
  280. self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
  281. return bool(inplace.value)
  282. def enable_cpu_inplace_mode(self):
  283. """
  284. set cpu forward in inplace mode with which cpu forward only create one
  285. thread
  286. Note: this must be set before the network loaded
  287. """
  288. self._api.LITE_set_cpu_inplace_mode(self._network)
  289. def use_tensorrt(self):
  290. """
  291. Note: this must be set before the network loaded
  292. """
  293. self._api.LITE_use_tensorrt(self._network)
  294. @property
  295. def device_id(self):
  296. """
  297. get the device id
  298. """
  299. device_id = c_int()
  300. self._api.LITE_get_device_id(self._network, byref(device_id))
  301. return device_id.value
  302. @device_id.setter
  303. def device_id(self, device_id):
  304. """
  305. set the device id
  306. Note: this must be set before the network loaded
  307. """
  308. self._api.LITE_set_device_id(self._network, device_id)
  309. @property
  310. def stream_id(self):
  311. """
  312. get the stream id
  313. """
  314. stream_id = c_int()
  315. self._api.LITE_get_stream_id(self._network, byref(stream_id))
  316. return stream_id.value
  317. @stream_id.setter
  318. def stream_id(self, stream_id):
  319. """
  320. set the stream id
  321. Note: this must be set before the network loaded
  322. """
  323. self._api.LITE_set_stream_id(self._network, stream_id)
  324. @property
  325. def threads_number(self):
  326. """
  327. get the thread number of the netwrok
  328. """
  329. nr_thread = c_size_t()
  330. self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
  331. return nr_thread.value
  332. @threads_number.setter
  333. def threads_number(self, nr_threads):
  334. """
  335. set the network forward in multithread mode, and the thread number
  336. Note: this must be set before the network loaded
  337. """
  338. self._api.LITE_set_cpu_threads_number(self._network, nr_threads)
  339. def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
  340. """
  341. get input or output tensor by its name
  342. """
  343. if type(name) == str:
  344. c_name = c_char_p(name.encode("utf-8"))
  345. else:
  346. c_name = c_char_p(name)
  347. tensor = LiteTensor()
  348. self._api.LITE_get_io_tensor(
  349. self._network, c_name, phase, byref(tensor._tensor)
  350. )
  351. tensor.update()
  352. return tensor
  353. def get_input_name(self, index):
  354. """
  355. get the input name by the index in the network
  356. """
  357. c_name = c_char_p()
  358. self._api.LITE_get_input_name(self._network, index, byref(c_name))
  359. return c_name.value.decode("utf-8")
  360. def get_output_name(self, index):
  361. """
  362. get the output name by the index in the network
  363. """
  364. c_name = c_char_p()
  365. self._api.LITE_get_output_name(self._network, index, byref(c_name))
  366. return c_name.value.decode("utf-8")
  367. def get_all_input_name(self):
  368. """
  369. get all the input tensor name in the network
  370. """
  371. nr_input = c_size_t()
  372. self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)
  373. if nr_input.value > 0:
  374. names = (c_char_p * nr_input.value)()
  375. self._api.LITE_get_all_input_name(self._network, None, names)
  376. ret_name = [names[i].decode("utf-8") for i in range(nr_input.value)]
  377. return ret_name
  378. def get_all_output_name(self):
  379. """
  380. get all the output tensor name in the network
  381. """
  382. nr_output = c_size_t()
  383. self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)
  384. if nr_output.value > 0:
  385. names = (c_char_p * nr_output.value)()
  386. self._api.LITE_get_all_output_name(self._network, None, names)
  387. ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
  388. return ret_name
  389. def share_weights_with(self, src_network):
  390. """
  391. share weights with the loaded network
  392. """
  393. assert isinstance(src_network, LiteNetwork)
  394. self._api.LITE_shared_weight_with_network(self._network, src_network._network)
  395. def share_runtime_memroy(self, src_network):
  396. """
  397. share runtime memory with the srouce network
  398. """
  399. assert isinstance(src_network, LiteNetwork)
  400. self._api.LITE_share_runtime_memroy(self._network, src_network._network)
  401. def async_with_callback(self, async_callback):
  402. async_callback = LiteAsyncCallback(async_callback)
  403. self._api.LITE_set_async_callback(self._network, async_callback)
  404. def set_start_callback(self, start_callback):
  405. """
  406. when the network start forward, the callback will be called,
  407. the start_callback with param mapping from LiteIO to the corresponding
  408. LiteTensor
  409. """
  410. self._api.LITE_set_start_callback(self._network, start_callback)
  411. def set_finish_callback(self, finish_callback):
  412. """
  413. when the network finish forward, the callback will be called,
  414. the finish_callback with param mapping from LiteIO to the corresponding
  415. LiteTensor
  416. """
  417. self._api.LITE_set_finish_callback(self._network, finish_callback)
  418. def enable_profile_performance(self, profile_file):
  419. c_file = profile_file.encode("utf-8")
  420. self._api.LITE_enable_profile_performance(self._network, c_file)
  421. def set_network_algo_workspace_limit(self, size_limit):
  422. self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)
  423. def set_network_algo_policy(
  424. self, policy, shared_batch_size=0, binary_equal_between_batch=False
  425. ):
  426. """
  427. shared_batch_size: the batch size used by fastrun,
  428. Non-zero value means that fastrun use this batch size
  429. regardless of the batch size of the model. Zero means
  430. fastrun use batch size of the model
  431. binary_equal_between_batch: if the content of each input batch is
  432. binary equal,whether the content of each output batch is
  433. promised to be equal
  434. """
  435. self._api.LITE_set_network_algo_policy(self._network, policy)
  436. self._api.LITE_set_network_algo_fastrun_config(
  437. self._network, shared_batch_size, binary_equal_between_batch
  438. )
  439. def io_txt_dump(self, txt_file):
  440. c_file = txt_file.encode("utf-8")
  441. self._api.LITE_enable_io_txt_dump(self._network, c_file)
  442. def io_bin_dump(self, bin_dir):
  443. c_dir = bin_dir.encode("utf-8")
  444. self._api.LITE_enable_io_bin_dump(self._network, c_dir)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台