You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from ctypes import *
  10. import numpy as np
  11. from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
  12. from .struct import *
  13. from .tensor import *
  class LiteOptions(Structure):
      """
      Inference options used to configure a network.

      Every field is declared as a C ``int``. ``__init__`` assigns defaults to
      the boolean switches and the level fields only; the layout-transform
      fields (``enable_*``) are declared but not assigned here.
      """

      _fields_ = [
          (field_name, c_int)
          for field_name in (
              "weight_preprocess",
              "fuse_preprocess",
              "fake_next_exec",
              "var_sanity_check_first_run",
              "const_shape",
              "force_dynamic_alloc",
              "force_output_dynamic_alloc",
              "no_profiling_on_shape_change",
              "jit_level",
              "comp_node_seq_record_level",
              "graph_opt_level",
              "async_exec_level",
              # layout transform options
              "enable_nchw44",
              "enable_nchw44_dot",
              "enable_nchw88",
              "enable_nhwcd4",
              "enable_nchw4",
              "enable_nchw32",
              "enable_nchw64",
          )
      ]

      def __init__(self):
          # (field, default) pairs, applied in declaration order.
          defaults = (
              ("weight_preprocess", False),
              ("fuse_preprocess", False),
              ("fake_next_exec", False),
              ("var_sanity_check_first_run", True),
              ("const_shape", False),
              ("force_dynamic_alloc", False),
              ("force_output_dynamic_alloc", False),
              ("no_profiling_on_shape_change", False),
              ("jit_level", 0),
              ("comp_node_seq_record_level", 0),
              ("graph_opt_level", 2),
              ("async_exec_level", 1),
          )
          for field_name, default in defaults:
              setattr(self, field_name, default)

      def __repr__(self):
          # Switches are reported as booleans, levels as their raw integers.
          switches = (
              "weight_preprocess",
              "fuse_preprocess",
              "fake_next_exec",
              "var_sanity_check_first_run",
              "const_shape",
              "force_dynamic_alloc",
              "force_output_dynamic_alloc",
              "no_profiling_on_shape_change",
          )
          levels = (
              "jit_level",
              "comp_node_seq_record_level",
              "graph_opt_level",
              "async_exec_level",
          )
          summary = {key: bool(getattr(self, key)) for key in switches}
          for key in levels:
              summary[key] = getattr(self, key)
          return repr(summary)
  class LiteConfig(Structure):
      """
      Configuration used when loading and compiling the graph.

      bare_model_cryption_name: name of the cryption method of a bare model;
          a bare model is one that has no model info packed inside.
      use_loader_dynamic_param: when the model runs forward with the device
          loader of an NPU, flags whether the loader uses device input or
          output; non-zero if it does, zero otherwise.
      has_compression: flags whether the model is compressed; the compression
          method is used to read the model.
      """

      _fields_ = [
          ("has_compression", c_int),
          ("device_id", c_int),
          ("device_type", c_int),
          ("backend", c_int),
          ("bare_model_cryption_name", c_char_p),
          ("options", LiteOptions),
      ]

      def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
          """
          device_type: value stored into the ``device_type`` field.
          option: a LiteOptions to store into ``options``; a fresh
              LiteOptions() is used when it is not given.
          """
          self.device_type = device_type
          self.options = option if option else LiteOptions()
          self.bare_model_cryption_name = c_char_p(b"")
          # NOTE(review): use_loader_dynamic_param is not listed in _fields_, so
          # this only sets a Python-side attribute -- confirm against the C struct.
          self.use_loader_dynamic_param = 0
          self.has_compression = 0
          self.backend = LiteBackend.LITE_DEFAULT

      def __repr__(self):
          # A c_char_p field reads back as None when the pointer is NULL.
          cryption_name = self.bare_model_cryption_name or b""
          data = {
              "has_compression": bool(self.has_compression),
              # device_id is a plain numeric id, not a LiteDeviceType member, so
              # it is reported as-is instead of being converted to the enum.
              "device_id": self.device_id,
              "device_type": LiteDeviceType(self.device_type),
              "backend": LiteBackend(self.backend),
              "bare_model_cryption_name": cryption_name.decode("utf-8"),
              "options": self.options,
          }
          return data.__repr__()
  class LiteIO(Structure):
      """
      Configures one network input or output item.

      name: the tensor name in the graph corresponding to the IO.
      is_host: marks where the input tensor comes from and where the output is
          copied to. If is_host is true, the input comes from the host and the
          output is copied to the host, otherwise the device. Sometimes the
          input comes from the device and the output does not need to be
          copied to the host. Default is true.
      io_type: the IO type, either SHAPE or VALUE. When SHAPE is set, the input
          or output tensor value is invalid and only the shape is set. Default
          is VALUE.
      config_layout: the layout configured by the user. If another layout is
          set before forward, or fetched after forward, this layout is
          bypassed. If no other layout is set before forward, this layout is
          used. If this layout is not set, the model forwards with its original
          layout. For an output it is used for checking.
      """

      _fields_ = [
          ("name", c_char_p),
          ("is_host", c_int),
          ("io_type", c_int),
          ("config_layout", LiteLayout),
      ]

      def __init__(
          self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
      ):
          # isinstance (not an exact type comparison) so that str subclasses are
          # encoded as well; anything else is handed to c_char_p unchanged.
          if isinstance(name, str):
              name = name.encode("utf-8")
          self.name = c_char_p(name)

          self.config_layout = layout if layout else LiteLayout()
          self.is_host = is_host
          self.io_type = io_type

      def __repr__(self):
          data = {
              "name": self.name,
              "is_host": bool(self.is_host),
              "io_type": LiteIOType(self.io_type),
              "config_layout": self.config_layout,
          }
          return data.__repr__()

      def __hash__(self):
          # Hash on the name field only.
          return hash(self.name)
  class _LiteNetworkIO(Structure):
      """
      C-layout description of the network inputs and outputs used when the
      network is loaded: two LiteIO pointers plus their element counts.
      """

      _fields_ = [
          ("inputs", POINTER(LiteIO)),
          ("outputs", POINTER(LiteIO)),
          ("input_size", c_size_t),
          ("output_size", c_size_t),
      ]

      def __init__(self):
          # Start out with no IO items: default-constructed pointers, zero counts.
          io_pointer_type = POINTER(LiteIO)
          self.inputs = io_pointer_type()
          self.outputs = io_pointer_type()
          self.input_size = self.output_size = 0
  class LiteNetworkIO(object):
      """
      User-facing collection of input and output LiteIO items, from which a
      _LiteNetworkIO is constructed.
      """

      def __init__(self):
          self.inputs, self.outputs = [], []

      def add_input(self, input_io):
          """Append one LiteIO to the inputs."""
          assert isinstance(input_io, LiteIO)
          self.inputs.append(input_io)

      def add_output(self, output_io):
          """Append one LiteIO to the outputs."""
          assert isinstance(output_io, LiteIO)
          self.outputs.append(output_io)

      def _create_network_io(self):
          """
          Build the _LiteNetworkIO that points at C arrays holding the
          collected items. The arrays are stored on ``self`` as ``c_inputs``
          and ``c_outputs``.
          """
          c_io = _LiteNetworkIO()

          # The arrays always have at least one element, so element 0 can be
          # addressed even when nothing was added; the size fields still carry
          # the real counts.
          self.c_inputs = (LiteIO * max(len(self.inputs), 1))(*self.inputs)
          self.c_outputs = (LiteIO * max(len(self.outputs), 1))(*self.outputs)

          c_io.inputs = pointer(self.c_inputs[0])
          c_io.outputs = pointer(self.c_outputs[0])
          c_io.input_size = len(self.inputs)
          c_io.output_size = len(self.outputs)
          return c_io

      def __repr__(self):
          return repr(dict(inputs=list(self.inputs), outputs=list(self.outputs)))
  # C callback type taking no arguments and returning an int.
  LiteAsyncCallback = CFUNCTYPE(c_int)


  def start_finish_callback(func):
      """
      Wrap ``func`` as a C callback receiving (LiteIO*, tensor*, size).

      When invoked, the wrapper builds a dict mapping each of the ``size``
      LiteIO items to a LiteTensor bound to the matching C tensor, calls
      ``func`` with that dict and returns its result.
      """
      c_signature = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)

      def wrapper(c_ios, c_tensors, size):
          io_to_tensor = {}
          for index in range(size):
              lite_tensor = LiteTensor()
              lite_tensor._tensor = c_tensors[index]
              lite_tensor.update()
              io_to_tensor[c_ios[index]] = lite_tensor
          return func(io_to_tensor)

      return c_signature(wrapper)
  class _NetworkAPI(_LiteCObjBase):
      """
      Get the network API from the lib.

      Each ``_api_`` entry pairs a C symbol name with the list of ctypes
      argument types for that function.
      """

      _api_ = [
          ("LITE_make_default_network", [POINTER(_Cnetwork)]),
          ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
          ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
          ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
          # The second argument is the source network handle: the only call
          # site passes ``src_network._network``, which is a _Cnetwork.
          ("LITE_shared_weight_with_network", [_Cnetwork, _Cnetwork]),
          ("LITE_destroy_network", [_Cnetwork]),
          ("LITE_forward", [_Cnetwork]),
          ("LITE_wait", [_Cnetwork]),
          ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
          ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
          ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
          ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
          ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
          ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
          ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
          ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
          ("LITE_set_device_id", [_Cnetwork, c_int]),
          ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
          ("LITE_use_tensorrt", [_Cnetwork]),
          ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
          ("LITE_set_stream_id", [_Cnetwork, c_int]),
          ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
          ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
          ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
          ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
          ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
          ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
          ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
          ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
          ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
          # NOTE(review): the two entries below declare a single argument, but
          # their call sites also pass a callback as a second argument --
          # confirm whether the callback type should be declared here.
          ("LITE_set_start_callback", [_Cnetwork]),
          ("LITE_set_finish_callback", [_Cnetwork]),
      ]
  class LiteNetwork(object):
      """
      The network used to load a model and run forward.
      """

      _api = _NetworkAPI()._lib

      def __init__(self, config=None, io=None):
          """
          Create a network with a config and a network IO.

          config: a LiteConfig; a default LiteConfig() is used when not given.
          io: a LiteNetworkIO; an empty LiteNetworkIO() is used when not given.
          """
          self._network = _Cnetwork()
          # ctypes callback objects handed to the C library are kept here so
          # they are not garbage-collected while the library may still call them.
          self._async_callback = None
          self._start_callback = None
          self._finish_callback = None

          self.config = config if config else LiteConfig()
          self.network_io = io if io else LiteNetworkIO()

          c_network_io = self.network_io._create_network_io()
          self._api.LITE_make_network(byref(self._network), self.config, c_network_io)

      def __repr__(self):
          data = {"config": self.config, "IOs": self.network_io}
          return data.__repr__()

      def __del__(self):
          self._api.LITE_destroy_network(self._network)

      def load(self, path):
          """
          Load the model from the file at ``path`` (a str, encoded as UTF-8).
          """
          c_path = c_char_p(path.encode("utf-8"))
          self._api.LITE_load_model_from_path(self._network, c_path)

      def forward(self):
          self._api.LITE_forward(self._network)

      def wait(self):
          self._api.LITE_wait(self._network)

      def is_cpu_inplace_mode(self):
          """
          Whether the network runs in cpu inplace mode.
          """
          inplace = c_int()
          self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
          return bool(inplace.value)

      def enable_cpu_inplace_mode(self):
          """
          Set cpu forward to inplace mode, in which cpu forward only creates
          one thread.
          Note: this must be set before the network is loaded
          """
          self._api.LITE_set_cpu_inplace_mode(self._network)

      def use_tensorrt(self):
          """
          Note: this must be set before the network is loaded
          """
          self._api.LITE_use_tensorrt(self._network)

      @property
      def device_id(self):
          """
          Get the device id.
          """
          device_id = c_int()
          self._api.LITE_get_device_id(self._network, byref(device_id))
          return device_id.value

      @device_id.setter
      def device_id(self, device_id):
          """
          Set the device id.
          Note: this must be set before the network is loaded
          """
          self._api.LITE_set_device_id(self._network, device_id)

      @property
      def stream_id(self):
          """
          Get the stream id.
          """
          stream_id = c_int()
          self._api.LITE_get_stream_id(self._network, byref(stream_id))
          return stream_id.value

      @stream_id.setter
      def stream_id(self, stream_id):
          """
          Set the stream id.
          Note: this must be set before the network is loaded
          """
          self._api.LITE_set_stream_id(self._network, stream_id)

      @property
      def threads_number(self):
          """
          Get the thread number of the network.
          """
          nr_thread = c_size_t()
          self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
          return nr_thread.value

      @threads_number.setter
      def threads_number(self, nr_threads):
          """
          Set the network to forward in multithread mode, with the given
          thread number.
          Note: this must be set before the network is loaded
          """
          self._api.LITE_set_cpu_threads_number(self._network, nr_threads)

      def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
          """
          Get an input or output tensor by its name.

          name: a str (encoded as UTF-8) or a value accepted by c_char_p.
          Returns a LiteTensor bound to the tensor handle filled in by the lib.
          """
          # isinstance so that str subclasses are encoded too.
          if isinstance(name, str):
              name = name.encode("utf-8")
          c_name = c_char_p(name)
          tensor = LiteTensor()
          self._api.LITE_get_io_tensor(
              self._network, c_name, phase, byref(tensor._tensor)
          )
          tensor.update()
          return tensor

      def get_input_name(self, index):
          """
          Get the input name by its index in the network.
          """
          c_name = c_char_p()
          self._api.LITE_get_input_name(self._network, index, byref(c_name))
          return c_name.value.decode("utf-8")

      def get_output_name(self, index):
          """
          Get the output name by its index in the network.
          """
          c_name = c_char_p()
          self._api.LITE_get_output_name(self._network, index, byref(c_name))
          return c_name.value.decode("utf-8")

      def get_all_input_name(self):
          """
          Get the names of all input tensors in the network.

          Returns a list of str; the list is empty when the lib reports no
          inputs.
          """
          # First call only queries the count, second call fills the names.
          nr_input = c_size_t()
          self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)
          if nr_input.value == 0:
              return []
          names = (c_char_p * nr_input.value)()
          self._api.LITE_get_all_input_name(self._network, None, names)
          return [name.decode("utf-8") for name in names]

      def get_all_output_name(self):
          """
          Get the names of all output tensors in the network.

          Returns a list of str; the list is empty when the lib reports no
          outputs.
          """
          # First call only queries the count, second call fills the names.
          nr_output = c_size_t()
          self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)
          if nr_output.value == 0:
              return []
          names = (c_char_p * nr_output.value)()
          self._api.LITE_get_all_output_name(self._network, None, names)
          return [name.decode("utf-8") for name in names]

      def share_weights_with(self, src_network):
          """
          Share weights with the loaded network.
          """
          assert isinstance(src_network, LiteNetwork)
          self._api.LITE_shared_weight_with_network(self._network, src_network._network)

      def share_runtime_memroy(self, src_network):
          """
          Share runtime memory with the source network.
          """
          assert isinstance(src_network, LiteNetwork)
          self._api.LITE_share_runtime_memroy(self._network, src_network._network)

      def async_with_callback(self, async_callback):
          """
          Register ``async_callback`` (a Python callable wrapped as a
          LiteAsyncCallback) as the async callback of the network.
          """
          # Keep the ctypes wrapper referenced from the instance: if it were
          # only a local, it could be collected while C still holds its pointer.
          self._async_callback = LiteAsyncCallback(async_callback)
          self._api.LITE_set_async_callback(self._network, self._async_callback)

      def set_start_callback(self, start_callback):
          """
          When the network starts forward, the callback is called. The
          start_callback takes a mapping from LiteIO to the corresponding
          LiteTensor.
          """
          # Hold a reference for as long as the network may call it.
          self._start_callback = start_callback
          self._api.LITE_set_start_callback(self._network, start_callback)

      def set_finish_callback(self, finish_callback):
          """
          When the network finishes forward, the callback is called. The
          finish_callback takes a mapping from LiteIO to the corresponding
          LiteTensor.
          """
          # Hold a reference for as long as the network may call it.
          self._finish_callback = finish_callback
          self._api.LITE_set_finish_callback(self._network, finish_callback)

      def enable_profile_performance(self, profile_file):
          c_file = profile_file.encode("utf-8")
          self._api.LITE_enable_profile_performance(self._network, c_file)

      def set_network_algo_workspace_limit(self, size_limit):
          self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)

      def set_network_algo_policy(
          self, policy, shared_batch_size=0, binary_equal_between_batch=False
      ):
          """
          shared_batch_size: the batch size used by fastrun. A non-zero value
              means fastrun uses this batch size regardless of the batch size
              of the model. Zero means fastrun uses the batch size of the
              model.
          binary_equal_between_batch: if the content of each input batch is
              binary equal, whether the content of each output batch is
              promised to be equal.
          """
          self._api.LITE_set_network_algo_policy(self._network, policy)
          self._api.LITE_set_network_algo_fastrun_config(
              self._network, shared_batch_size, binary_equal_between_batch
          )

      def io_txt_dump(self, txt_file):
          c_file = txt_file.encode("utf-8")
          self._api.LITE_enable_io_txt_dump(self._network, c_file)

      def io_bin_dump(self, bin_dir):
          c_dir = bin_dir.encode("utf-8")
          self._api.LITE_enable_io_bin_dump(self._network, c_dir)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台