You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

network.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from ctypes import *
  10. import numpy as np
  11. from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
  12. from .struct import *
  13. from .tensor import *
  14. class LiteOptions(Structure):
  15. """
  16. the inference options will be used to config a network
  17. """
  18. _fields_ = [
  19. ("weight_preprocess", c_int),
  20. ("fuse_preprocess", c_int),
  21. ("fake_next_exec", c_int),
  22. ("var_sanity_check_first_run", c_int),
  23. ("const_shape", c_int),
  24. ("force_dynamic_alloc", c_int),
  25. ("force_output_dynamic_alloc", c_int),
  26. ("force_output_use_user_specified_memory", c_int),
  27. ("no_profiling_on_shape_change", c_int),
  28. ("jit_level", c_int),
  29. ("comp_node_seq_record_level", c_int),
  30. ("graph_opt_level", c_int),
  31. ("async_exec_level", c_int),
  32. # layout transform options
  33. ("enable_nchw44", c_int),
  34. ("enable_nchw44_dot", c_int),
  35. ("enable_nchw88", c_int),
  36. ("enable_nhwcd4", c_int),
  37. ("enable_nchw4", c_int),
  38. ("enable_nchw32", c_int),
  39. ("enable_nchw64", c_int),
  40. ]
  41. def __init__(self):
  42. self.weight_preprocess = False
  43. self.fuse_preprocess = False
  44. self.fake_next_exec = False
  45. self.var_sanity_check_first_run = True
  46. self.const_shape = False
  47. self.force_dynamic_alloc = False
  48. self.force_output_dynamic_alloc = False
  49. self.force_output_use_user_specified_memory = False
  50. self.no_profiling_on_shape_change = False
  51. self.jit_level = 0
  52. self.comp_node_seq_record_level = 0
  53. self.graph_opt_level = 2
  54. self.async_exec_level = 1
  55. def __repr__(self):
  56. data = {
  57. "weight_preprocess": bool(self.weight_preprocess),
  58. "fuse_preprocess": bool(self.fuse_preprocess),
  59. "fake_next_exec": bool(self.fake_next_exec),
  60. "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
  61. "const_shape": bool(self.const_shape),
  62. "force_dynamic_alloc": bool(self.force_dynamic_alloc),
  63. "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
  64. "force_output_use_user_specified_memory": bool(
  65. self.force_output_use_user_specified_memory
  66. ),
  67. "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
  68. "jit_level": self.jit_level,
  69. "comp_node_seq_record_level": self.comp_node_seq_record_level,
  70. "graph_opt_level": self.graph_opt_level,
  71. "async_exec_level": self.async_exec_level,
  72. }
  73. return data.__repr__()
  74. class LiteConfig(Structure):
  75. """
  76. Configuration when load and compile the graph
  77. bare_model_cryption_name: is the bare model cryption method name, bare
  78. model is not pack model info inside
  79. use_loader_dynamic_param: when model forward with device loader of npu,
  80. use_loader_dynamic_param used to flag whether the loader use device input or
  81. output, if use device input or output it will set Non-zero , else set zero
  82. has_compression: flag whether the model is compressed, the compress
  83. method will used to read the model
  84. """
  85. _fields_ = [
  86. ("has_compression", c_int),
  87. ("device_id", c_int),
  88. ("device_type", c_int),
  89. ("backend", c_int),
  90. ("_bare_model_cryption_name", c_char_p),
  91. ("options", LiteOptions),
  92. ]
  93. def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
  94. self.device_type = device_type
  95. if option:
  96. self.options = option
  97. else:
  98. self.options = LiteOptions()
  99. self._bare_model_cryption_name = c_char_p(b"")
  100. self.use_loader_dynamic_param = 0
  101. self.has_compression = 0
  102. self.backend = LiteBackend.LITE_DEFAULT
  103. @property
  104. def bare_model_cryption_name(self):
  105. return self._bare_model_cryption_name.decode("utf-8")
  106. @bare_model_cryption_name.setter
  107. def bare_model_cryption_name(self, name):
  108. if isinstance(name, str):
  109. self._bare_model_cryption_name = name.encode("utf-8")
  110. else:
  111. assert isinstance(name, bytes), "name should be str or bytes type."
  112. self._bare_model_cryption_name = name
  113. def __repr__(self):
  114. data = {
  115. "has_compression": bool(self.has_compression),
  116. "device_id": LiteDeviceType(self.device_id),
  117. "device_type": LiteDeviceType(self.device_type),
  118. "backend": LiteBackend(self.backend),
  119. "bare_model_cryption_name": self.bare_model_cryption_name,
  120. "options": self.options,
  121. }
  122. return data.__repr__()
  123. class LiteIO(Structure):
  124. """
  125. config the network input and output item
  126. name: the tensor name in the graph corresponding to the IO
  127. is_host: Used to mark where the input tensor comes from and the output where copy
  128. to, if is_host is true, the input is from host and output copy to host,
  129. otherwise device. Sometimes The input is from device and output no need
  130. copy to host, default is true.
  131. io_type: The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
  132. output tensor value is invaid, only shape will be set, default is VALUE
  133. config_layout: The layout of the config from user, if other layout is set before
  134. forward or get after forward, this layout will by pass. if no other
  135. layout is set before forward, this layout will work. if this layout is
  136. no set, the model will forward with its origin layout. if in output, it
  137. will used to check.
  138. """
  139. _fields_ = [
  140. ("_name", c_char_p),
  141. ("is_host", c_int),
  142. ("io_type", c_int),
  143. ("config_layout", LiteLayout),
  144. ]
  145. def __init__(
  146. self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  147. ):
  148. if type(name) == str:
  149. self._name = c_char_p(name.encode("utf-8"))
  150. else:
  151. self._name = c_char_p(name)
  152. if layout:
  153. self.config_layout = layout
  154. else:
  155. self.config_layout = LiteLayout()
  156. self.is_host = is_host
  157. self.io_type = io_type
  158. @property
  159. def name(self):
  160. return self._name.decode("utf-8")
  161. @name.setter
  162. def name(self, name):
  163. if isinstance(name, str):
  164. self._name = name.encode("utf-8")
  165. else:
  166. assert isinstance(name, bytes), "name should be str or bytes type."
  167. self._name = name
  168. def __repr__(self):
  169. data = {
  170. "name": self.name,
  171. "is_host": bool(self.is_host),
  172. "io_type": LiteIOType(self.io_type),
  173. "config_layout": self.config_layout,
  174. }
  175. return data.__repr__()
  176. def __hash__(self):
  177. return hash(self.name)
  178. class _LiteNetworkIO(Structure):
  179. """
  180. the input and output information when load the network
  181. """
  182. _fields_ = [
  183. ("inputs", POINTER(LiteIO)),
  184. ("outputs", POINTER(LiteIO)),
  185. ("input_size", c_size_t),
  186. ("output_size", c_size_t),
  187. ]
  188. def __init__(self):
  189. self.inputs = POINTER(LiteIO)()
  190. self.outputs = POINTER(LiteIO)()
  191. self.input_size = 0
  192. self.output_size = 0
  193. class LiteNetworkIO(object):
  194. """
  195. the input and output information for user to construct _LiteNetWorkIO
  196. """
  197. def __init__(self, inputs=None, outputs=None):
  198. self.inputs = []
  199. self.outputs = []
  200. if inputs:
  201. for i in inputs:
  202. if isinstance(i, list):
  203. self.inputs.append(LiteIO(*i))
  204. else:
  205. assert isinstance(
  206. i, LiteIO
  207. ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
  208. self.inputs.append(i)
  209. if outputs:
  210. for i in outputs:
  211. if isinstance(i, list):
  212. self.outputs.append(LiteIO(*i))
  213. else:
  214. assert isinstance(
  215. i, LiteIO
  216. ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
  217. self.outputs.append(i)
  218. def add_input(
  219. self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  220. ):
  221. if isinstance(obj, LiteIO):
  222. self.inputs.append(obj)
  223. else:
  224. name = obj
  225. self.add_input(LiteIO(name, is_host, io_type, layout))
  226. def add_output(
  227. self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  228. ):
  229. if isinstance(obj, LiteIO):
  230. self.outputs.append(obj)
  231. else:
  232. name = obj
  233. self.add_output(LiteIO(name, is_host, io_type, layout))
  234. def _create_network_io(self):
  235. network_io = _LiteNetworkIO()
  236. length = 1 if len(self.inputs) == 0 else len(self.inputs)
  237. self.c_inputs = (LiteIO * length)(*self.inputs)
  238. length = 1 if len(self.outputs) == 0 else len(self.outputs)
  239. self.c_outputs = (LiteIO * length)(*self.outputs)
  240. network_io.inputs = pointer(self.c_inputs[0])
  241. network_io.outputs = pointer(self.c_outputs[0])
  242. network_io.input_size = len(self.inputs)
  243. network_io.output_size = len(self.outputs)
  244. return network_io
  245. def __repr__(self):
  246. data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
  247. return data.__repr__()
# ctypes prototypes for callbacks the C library invokes.
# LiteAsyncCallback takes no arguments and returns an int status.
LiteAsyncCallback = CFUNCTYPE(c_int)
# Start/finish callbacks receive parallel arrays of LiteIO descriptors and
# raw tensor handles, plus the number of entries; they return an int status.
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
def wrap_async_callback(func):
    """Wrap a no-argument Python callable as a LiteAsyncCallback ctypes object."""
    # NOTE(review): the wrapper is published as a module-level global,
    # presumably so the ctypes callback object stays referenced (and is not
    # garbage-collected) while the C side still holds it — confirm. Only the
    # most recently created wrapper is kept alive this way.
    global wrapper

    @CFUNCTYPE(c_int)
    def wrapper():
        return func()

    return wrapper
def start_finish_callback(func):
    """Wrap a Python callable as a start/finish ctypes callback.

    The wrapped ``func`` receives one argument: a dict mapping each LiteIO to
    its corresponding LiteTensor.
    """
    # NOTE(review): published as a module-level global, presumably to keep the
    # ctypes callback object alive while the C side holds it — confirm. Only
    # the most recently created wrapper survives.
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        # Rebuild a {LiteIO: LiteTensor} mapping from the parallel C arrays.
        ios = {}
        for i in range(size):
            # Wrap the raw tensor handle without copying, then refresh the
            # Python-side metadata from the C object.
            tensor = LiteTensor()
            tensor._tensor = c_void_p(c_tensors[i])
            tensor.update()
            io = c_ios[i]
            ios[io] = tensor
        return func(ios)

    return wrapper
class _NetworkAPI(_LiteCObjBase):
    """
    get the network api from the lib
    """

    # Each entry is (exported C symbol name, argument ctypes); the symbols are
    # resolved from the shared library by _LiteCObjBase.
    _api_ = [
        # construction / loading
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        # execution
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        # IO tensor / name queries
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        # device / threading configuration
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        # algorithm policy / workspace
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        # NOTE(review): "memroy" looks misspelled but presumably matches the
        # exported C symbol name — confirm against the library before renaming.
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        # profiling / debugging
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        # callbacks
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
        # memory / layout-transform tooling
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
    ]
  311. class LiteNetwork(object):
  312. """
  313. the network to load a model and forward
  314. """
  315. _api = _NetworkAPI()._lib
  316. def __init__(self, config=None, io=None):
  317. """
  318. create a network with config and networkio
  319. """
  320. self._network = _Cnetwork()
  321. if config:
  322. self.config = config
  323. else:
  324. self.config = LiteConfig()
  325. if io:
  326. self.network_io = io
  327. else:
  328. self.network_io = LiteNetworkIO()
  329. c_network_io = self.network_io._create_network_io()
  330. self._api.LITE_make_network(byref(self._network), self.config, c_network_io)
  331. def __repr__(self):
  332. data = {"config": self.config, "IOs": self.network_io}
  333. return data.__repr__()
  334. def __del__(self):
  335. self._api.LITE_destroy_network(self._network)
  336. def load(self, path):
  337. c_path = c_char_p(path.encode("utf-8"))
  338. self._api.LITE_load_model_from_path(self._network, c_path)
  339. def forward(self):
  340. self._api.LITE_forward(self._network)
  341. def wait(self):
  342. self._api.LITE_wait(self._network)
  343. def is_cpu_inplace_mode(self):
  344. """
  345. whether the network run in cpu inpalce mode
  346. """
  347. inplace = c_int()
  348. self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
  349. return bool(inplace.value)
  350. def enable_cpu_inplace_mode(self):
  351. """
  352. set cpu forward in inplace mode with which cpu forward only create one
  353. thread
  354. Note: this must be set before the network loaded
  355. """
  356. self._api.LITE_set_cpu_inplace_mode(self._network)
  357. def use_tensorrt(self):
  358. """
  359. Note: this must be set before the network loaded
  360. """
  361. self._api.LITE_use_tensorrt(self._network)
  362. @property
  363. def device_id(self):
  364. """
  365. get the device id
  366. """
  367. device_id = c_int()
  368. self._api.LITE_get_device_id(self._network, byref(device_id))
  369. return device_id.value
  370. @device_id.setter
  371. def device_id(self, device_id):
  372. """
  373. set the device id
  374. Note: this must be set before the network loaded
  375. """
  376. self._api.LITE_set_device_id(self._network, device_id)
  377. @property
  378. def stream_id(self):
  379. """
  380. get the stream id
  381. """
  382. stream_id = c_int()
  383. self._api.LITE_get_stream_id(self._network, byref(stream_id))
  384. return stream_id.value
  385. @stream_id.setter
  386. def stream_id(self, stream_id):
  387. """
  388. set the stream id
  389. Note: this must be set before the network loaded
  390. """
  391. self._api.LITE_set_stream_id(self._network, stream_id)
  392. @property
  393. def threads_number(self):
  394. """
  395. get the thread number of the netwrok
  396. """
  397. nr_thread = c_size_t()
  398. self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
  399. return nr_thread.value
  400. @threads_number.setter
  401. def threads_number(self, nr_threads):
  402. """
  403. set the network forward in multithread mode, and the thread number
  404. Note: this must be set before the network loaded
  405. """
  406. self._api.LITE_set_cpu_threads_number(self._network, nr_threads)
  407. def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
  408. """
  409. get input or output tensor by its name
  410. """
  411. if type(name) == str:
  412. c_name = c_char_p(name.encode("utf-8"))
  413. else:
  414. c_name = c_char_p(name)
  415. tensor = LiteTensor()
  416. self._api.LITE_get_io_tensor(
  417. self._network, c_name, phase, byref(tensor._tensor)
  418. )
  419. tensor.update()
  420. return tensor
  421. def get_input_name(self, index):
  422. """
  423. get the input name by the index in the network
  424. """
  425. c_name = c_char_p()
  426. self._api.LITE_get_input_name(self._network, index, byref(c_name))
  427. return c_name.value.decode("utf-8")
  428. def get_output_name(self, index):
  429. """
  430. get the output name by the index in the network
  431. """
  432. c_name = c_char_p()
  433. self._api.LITE_get_output_name(self._network, index, byref(c_name))
  434. return c_name.value.decode("utf-8")
  435. def get_all_input_name(self):
  436. """
  437. get all the input tensor name in the network
  438. """
  439. nr_input = c_size_t()
  440. self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)
  441. if nr_input.value > 0:
  442. names = (c_char_p * nr_input.value)()
  443. self._api.LITE_get_all_input_name(self._network, None, names)
  444. ret_name = [names[i].decode("utf-8") for i in range(nr_input.value)]
  445. return ret_name
  446. def get_all_output_name(self):
  447. """
  448. get all the output tensor name in the network
  449. """
  450. nr_output = c_size_t()
  451. self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)
  452. if nr_output.value > 0:
  453. names = (c_char_p * nr_output.value)()
  454. self._api.LITE_get_all_output_name(self._network, None, names)
  455. ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
  456. return ret_name
  457. def share_weights_with(self, src_network):
  458. """
  459. share weights with the loaded network
  460. """
  461. assert isinstance(src_network, LiteNetwork)
  462. self._api.LITE_shared_weight_with_network(self._network, src_network._network)
  463. def share_runtime_memroy(self, src_network):
  464. """
  465. share runtime memory with the srouce network
  466. """
  467. assert isinstance(src_network, LiteNetwork)
  468. self._api.LITE_share_runtime_memroy(self._network, src_network._network)
  469. def async_with_callback(self, async_callback):
  470. callback = wrap_async_callback(async_callback)
  471. self._api.LITE_set_async_callback(self._network, callback)
  472. def set_start_callback(self, start_callback):
  473. """
  474. when the network start forward, the callback will be called,
  475. the start_callback with param mapping from LiteIO to the corresponding
  476. LiteTensor
  477. """
  478. callback = start_finish_callback(start_callback)
  479. self._api.LITE_set_start_callback(self._network, callback)
  480. def set_finish_callback(self, finish_callback):
  481. """
  482. when the network finish forward, the callback will be called,
  483. the finish_callback with param mapping from LiteIO to the corresponding
  484. LiteTensor
  485. """
  486. callback = start_finish_callback(finish_callback)
  487. self._api.LITE_set_finish_callback(self._network, callback)
  488. def enable_profile_performance(self, profile_file):
  489. c_file = profile_file.encode("utf-8")
  490. self._api.LITE_enable_profile_performance(self._network, c_file)
  491. def set_network_algo_workspace_limit(self, size_limit):
  492. self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)
  493. def set_network_algo_policy(
  494. self, policy, shared_batch_size=0, binary_equal_between_batch=False
  495. ):
  496. """
  497. shared_batch_size: the batch size used by fastrun,
  498. Non-zero value means that fastrun use this batch size
  499. regardless of the batch size of the model. Zero means
  500. fastrun use batch size of the model
  501. binary_equal_between_batch: if the content of each input batch is
  502. binary equal,whether the content of each output batch is
  503. promised to be equal
  504. """
  505. self._api.LITE_set_network_algo_policy(self._network, policy)
  506. self._api.LITE_set_network_algo_fastrun_config(
  507. self._network, shared_batch_size, binary_equal_between_batch
  508. )
  509. def io_txt_dump(self, txt_file):
  510. c_file = txt_file.encode("utf-8")
  511. self._api.LITE_enable_io_txt_dump(self._network, c_file)
  512. def io_bin_dump(self, bin_dir):
  513. c_dir = bin_dir.encode("utf-8")
  514. self._api.LITE_enable_io_bin_dump(self._network, c_dir)
  515. def get_static_memory_alloc_info(self, log_dir="logs/test"):
  516. c_log_dir = log_dir.encode("utf-8")
  517. self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)
  518. def enable_global_layout_transform(self):
  519. self._api.LITE_enable_global_layout_transform(self._network)
  520. def dump_layout_transform_model(self, model_file):
  521. c_file = model_file.encode("utf-8")
  522. self._api.LITE_dump_layout_transform_model(self._network, c_file)