You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

network.py 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. # -*- coding: utf-8 -*-
  2. from ctypes import *
  3. import numpy as np
  4. from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
  5. from .struct import *
  6. from .tensor import *
  7. class LiteOptions(Structure):
  8. """
  9. the inference options will be used to config a network
  10. """
  11. _fields_ = [
  12. ("weight_preprocess", c_int),
  13. ("fuse_preprocess", c_int),
  14. ("fake_next_exec", c_int),
  15. ("var_sanity_check_first_run", c_int),
  16. ("const_shape", c_int),
  17. ("force_dynamic_alloc", c_int),
  18. ("force_output_dynamic_alloc", c_int),
  19. ("force_output_use_user_specified_memory", c_int),
  20. ("no_profiling_on_shape_change", c_int),
  21. ("jit_level", c_int),
  22. ("comp_node_seq_record_level", c_int),
  23. ("graph_opt_level", c_int),
  24. ("async_exec_level", c_int),
  25. # layout transform options
  26. ("enable_nchw44", c_int),
  27. ("enable_nchw44_dot", c_int),
  28. ("enable_nchw88", c_int),
  29. ("enable_nhwcd4", c_int),
  30. ("enable_nchw4", c_int),
  31. ("enable_nchw32", c_int),
  32. ("enable_nchw64", c_int),
  33. ]
  34. def __init__(self):
  35. self.weight_preprocess = False
  36. self.fuse_preprocess = False
  37. self.fake_next_exec = False
  38. self.var_sanity_check_first_run = True
  39. self.const_shape = False
  40. self.force_dynamic_alloc = False
  41. self.force_output_dynamic_alloc = False
  42. self.force_output_use_user_specified_memory = False
  43. self.no_profiling_on_shape_change = False
  44. self.jit_level = 0
  45. self.comp_node_seq_record_level = 0
  46. self.graph_opt_level = 2
  47. self.async_exec_level = 1
  48. def __repr__(self):
  49. data = {
  50. "weight_preprocess": bool(self.weight_preprocess),
  51. "fuse_preprocess": bool(self.fuse_preprocess),
  52. "fake_next_exec": bool(self.fake_next_exec),
  53. "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
  54. "const_shape": bool(self.const_shape),
  55. "force_dynamic_alloc": bool(self.force_dynamic_alloc),
  56. "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
  57. "force_output_use_user_specified_memory": bool(
  58. self.force_output_use_user_specified_memory
  59. ),
  60. "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
  61. "jit_level": self.jit_level,
  62. "comp_node_seq_record_level": self.comp_node_seq_record_level,
  63. "graph_opt_level": self.graph_opt_level,
  64. "async_exec_level": self.async_exec_level,
  65. }
  66. return data.__repr__()
  67. class LiteConfig(Structure):
  68. """
  69. Configuration when load and compile the graph
  70. bare_model_cryption_name: is the bare model cryption method name, bare
  71. model is not pack model info inside
  72. use_loader_dynamic_param: when model forward with device loader of npu,
  73. use_loader_dynamic_param used to flag whether the loader use device input or
  74. output, if use device input or output it will set Non-zero , else set zero
  75. has_compression: flag whether the model is compressed, the compress
  76. method will used to read the model
  77. """
  78. _fields_ = [
  79. ("has_compression", c_int),
  80. ("device_id", c_int),
  81. ("device_type", c_int),
  82. ("backend", c_int),
  83. ("_bare_model_cryption_name", c_char_p),
  84. ("options", LiteOptions),
  85. ]
  86. def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
  87. self.device_type = device_type
  88. if option:
  89. self.options = option
  90. else:
  91. self.options = LiteOptions()
  92. self._bare_model_cryption_name = c_char_p(b"")
  93. self.use_loader_dynamic_param = 0
  94. self.has_compression = 0
  95. self.backend = LiteBackend.LITE_DEFAULT
  96. @property
  97. def bare_model_cryption_name(self):
  98. return self._bare_model_cryption_name.decode("utf-8")
  99. @bare_model_cryption_name.setter
  100. def bare_model_cryption_name(self, name):
  101. if isinstance(name, str):
  102. self._bare_model_cryption_name = name.encode("utf-8")
  103. else:
  104. assert isinstance(name, bytes), "name should be str or bytes type."
  105. self._bare_model_cryption_name = name
  106. def __repr__(self):
  107. data = {
  108. "has_compression": bool(self.has_compression),
  109. "device_id": LiteDeviceType(self.device_id),
  110. "device_type": LiteDeviceType(self.device_type),
  111. "backend": LiteBackend(self.backend),
  112. "bare_model_cryption_name": self.bare_model_cryption_name,
  113. "options": self.options,
  114. }
  115. return data.__repr__()
  116. class LiteExtraConfig(Structure):
  117. """
  118. Extra configuration when load and compile the graph
  119. disable_configure_by_model_info: disable the configuration dumped with
  120. model, if set true, all configuration in the model will not apply, users
  121. should configure the network.
  122. """
  123. _fields_ = [
  124. ("disable_configure_by_model_info", c_int),
  125. ]
  126. def __init__(self, disable_model_config=False):
  127. self.disable_configure_by_model_info = disable_model_config
  128. def __repr__(self):
  129. data = {
  130. "disable_configure_by_model_info": bool(
  131. self.disable_configure_by_model_info
  132. ),
  133. }
  134. return data.__repr__()
  135. class LiteIO(Structure):
  136. """
  137. config the network input and output item
  138. name: the tensor name in the graph corresponding to the IO
  139. is_host: Used to mark where the input tensor comes from and the output where copy
  140. to, if is_host is true, the input is from host and output copy to host,
  141. otherwise device. Sometimes The input is from device and output no need
  142. copy to host, default is true.
  143. io_type: The IO type, it can be SHAPE or VALUE, when SHAPE is set, the input or
  144. output tensor value is invaid, only shape will be set, default is VALUE
  145. config_layout: The layout of the config from user, if other layout is set before
  146. forward or get after forward, this layout will by pass. if no other
  147. layout is set before forward, this layout will work. if this layout is
  148. no set, the model will forward with its origin layout. if in output, it
  149. will used to check.
  150. """
  151. _fields_ = [
  152. ("_name", c_char_p),
  153. ("is_host", c_int),
  154. ("io_type", c_int),
  155. ("config_layout", LiteLayout),
  156. ]
  157. def __init__(
  158. self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  159. ):
  160. if type(name) == str:
  161. self._name = c_char_p(name.encode("utf-8"))
  162. else:
  163. self._name = c_char_p(name)
  164. if layout:
  165. self.config_layout = layout
  166. else:
  167. self.config_layout = LiteLayout()
  168. self.is_host = is_host
  169. self.io_type = io_type
  170. @property
  171. def name(self):
  172. return self._name.decode("utf-8")
  173. @name.setter
  174. def name(self, name):
  175. if isinstance(name, str):
  176. self._name = name.encode("utf-8")
  177. else:
  178. assert isinstance(name, bytes), "name should be str or bytes type."
  179. self._name = name
  180. def __repr__(self):
  181. data = {
  182. "name": self.name,
  183. "is_host": bool(self.is_host),
  184. "io_type": LiteIOType(self.io_type),
  185. "config_layout": self.config_layout,
  186. }
  187. return data.__repr__()
  188. def __hash__(self):
  189. return hash(self.name)
  190. class _LiteNetworkIO(Structure):
  191. """
  192. the input and output information when load the network
  193. """
  194. _fields_ = [
  195. ("inputs", POINTER(LiteIO)),
  196. ("outputs", POINTER(LiteIO)),
  197. ("input_size", c_size_t),
  198. ("output_size", c_size_t),
  199. ]
  200. def __init__(self):
  201. self.inputs = POINTER(LiteIO)()
  202. self.outputs = POINTER(LiteIO)()
  203. self.input_size = 0
  204. self.output_size = 0
  205. class LiteNetworkIO(object):
  206. """
  207. the input and output information for user to construct _LiteNetWorkIO
  208. """
  209. def __init__(self, inputs=None, outputs=None):
  210. self.inputs = []
  211. self.outputs = []
  212. if inputs:
  213. for i in inputs:
  214. if isinstance(i, list):
  215. self.inputs.append(LiteIO(*i))
  216. else:
  217. assert isinstance(
  218. i, LiteIO
  219. ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
  220. self.inputs.append(i)
  221. if outputs:
  222. for i in outputs:
  223. if isinstance(i, list):
  224. self.outputs.append(LiteIO(*i))
  225. else:
  226. assert isinstance(
  227. i, LiteIO
  228. ), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO."
  229. self.outputs.append(i)
  230. def add_input(
  231. self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  232. ):
  233. if isinstance(obj, LiteIO):
  234. self.inputs.append(obj)
  235. else:
  236. name = obj
  237. self.add_input(LiteIO(name, is_host, io_type, layout))
  238. def add_output(
  239. self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
  240. ):
  241. if isinstance(obj, LiteIO):
  242. self.outputs.append(obj)
  243. else:
  244. name = obj
  245. self.add_output(LiteIO(name, is_host, io_type, layout))
  246. def _create_network_io(self):
  247. network_io = _LiteNetworkIO()
  248. length = 1 if len(self.inputs) == 0 else len(self.inputs)
  249. self.c_inputs = (LiteIO * length)(*self.inputs)
  250. length = 1 if len(self.outputs) == 0 else len(self.outputs)
  251. self.c_outputs = (LiteIO * length)(*self.outputs)
  252. network_io.inputs = pointer(self.c_inputs[0])
  253. network_io.outputs = pointer(self.c_outputs[0])
  254. network_io.input_size = len(self.inputs)
  255. network_io.output_size = len(self.outputs)
  256. return network_io
  257. def __repr__(self):
  258. data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
  259. return data.__repr__()
# ctypes prototype of the async callback installed via
# LiteNetwork.async_with_callback: no arguments, int return value.
LiteAsyncCallback = CFUNCTYPE(c_int)
# ctypes prototypes of the forward start/finish callbacks: each receives an
# array of LiteIO, an array of raw tensor handles and the element count, and
# returns an int status.
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
  263. def wrap_async_callback(func):
  264. global wrapper
  265. @CFUNCTYPE(c_int)
  266. def wrapper():
  267. return func()
  268. return wrapper
  269. def start_finish_callback(func):
  270. global wrapper
  271. @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
  272. def wrapper(c_ios, c_tensors, size):
  273. ios = {}
  274. for i in range(size):
  275. tensor = LiteTensor()
  276. tensor._tensor = c_void_p(c_tensors[i])
  277. tensor.update()
  278. io = c_ios[i]
  279. ios[io] = tensor
  280. return func(ios)
  281. return wrapper
class _NetworkAPI(_LiteCObjBase):
    """
    get the network api from the lib

    `_api_` lists each C entry point name with the ctypes argument types it
    is called with; the binding itself is presumably done by _LiteCObjBase
    (defined in .base — confirm there).
    """

    _api_ = [
        # network creation / loading / destruction
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        # execution
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        # IO tensor / name queries
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        # device / thread / stream configuration
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        # algorithm policy / workspace
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        # NOTE: "memroy" spelling matches the exported C symbol
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        # profiling / debugging dumps
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        # callbacks
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
        # memory info / layout transform
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
        # model IO inspection without building a network
        (
            "LITE_get_model_io_info_by_path",
            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        (
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
    ]
  332. class LiteNetwork(object):
  333. """
  334. the network to load a model and forward
  335. """
  336. _api = _NetworkAPI()._lib
  337. def __init__(self, config=None, io=None):
  338. """
  339. create a network with config and networkio
  340. """
  341. self._network = _Cnetwork()
  342. if config:
  343. self.config = config
  344. else:
  345. self.config = LiteConfig()
  346. if io:
  347. self.network_io = io
  348. else:
  349. self.network_io = LiteNetworkIO()
  350. c_network_io = self.network_io._create_network_io()
  351. self._api.LITE_make_network(byref(self._network), self.config, c_network_io)
  352. def __repr__(self):
  353. data = {"config": self.config, "IOs": self.network_io}
  354. return data.__repr__()
  355. def __del__(self):
  356. self._api.LITE_destroy_network(self._network)
  357. def load(self, path):
  358. c_path = c_char_p(path.encode("utf-8"))
  359. self._api.LITE_load_model_from_path(self._network, c_path)
  360. def forward(self):
  361. self._api.LITE_forward(self._network)
  362. def wait(self):
  363. self._api.LITE_wait(self._network)
  364. def is_cpu_inplace_mode(self):
  365. """
  366. whether the network run in cpu inpalce mode
  367. """
  368. inplace = c_int()
  369. self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
  370. return bool(inplace.value)
  371. def enable_cpu_inplace_mode(self):
  372. """
  373. set cpu forward in inplace mode with which cpu forward only create one
  374. thread
  375. Note: this must be set before the network loaded
  376. """
  377. self._api.LITE_set_cpu_inplace_mode(self._network)
  378. def use_tensorrt(self):
  379. """
  380. Note: this must be set before the network loaded
  381. """
  382. self._api.LITE_use_tensorrt(self._network)
  383. @property
  384. def device_id(self):
  385. """
  386. get the device id
  387. """
  388. device_id = c_int()
  389. self._api.LITE_get_device_id(self._network, byref(device_id))
  390. return device_id.value
  391. @device_id.setter
  392. def device_id(self, device_id):
  393. """
  394. set the device id
  395. Note: this must be set before the network loaded
  396. """
  397. self._api.LITE_set_device_id(self._network, device_id)
  398. @property
  399. def stream_id(self):
  400. """
  401. get the stream id
  402. """
  403. stream_id = c_int()
  404. self._api.LITE_get_stream_id(self._network, byref(stream_id))
  405. return stream_id.value
  406. @stream_id.setter
  407. def stream_id(self, stream_id):
  408. """
  409. set the stream id
  410. Note: this must be set before the network loaded
  411. """
  412. self._api.LITE_set_stream_id(self._network, stream_id)
  413. @property
  414. def threads_number(self):
  415. """
  416. get the thread number of the netwrok
  417. """
  418. nr_thread = c_size_t()
  419. self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
  420. return nr_thread.value
  421. @threads_number.setter
  422. def threads_number(self, nr_threads):
  423. """
  424. set the network forward in multithread mode, and the thread number
  425. Note: this must be set before the network loaded
  426. """
  427. self._api.LITE_set_cpu_threads_number(self._network, nr_threads)
  428. def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
  429. """
  430. get input or output tensor by its name
  431. """
  432. if type(name) == str:
  433. c_name = c_char_p(name.encode("utf-8"))
  434. else:
  435. c_name = c_char_p(name)
  436. tensor = LiteTensor()
  437. self._api.LITE_get_io_tensor(
  438. self._network, c_name, phase, byref(tensor._tensor)
  439. )
  440. tensor.update()
  441. return tensor
  442. def get_input_name(self, index):
  443. """
  444. get the input name by the index in the network
  445. """
  446. c_name = c_char_p()
  447. self._api.LITE_get_input_name(self._network, index, byref(c_name))
  448. return c_name.value.decode("utf-8")
  449. def get_output_name(self, index):
  450. """
  451. get the output name by the index in the network
  452. """
  453. c_name = c_char_p()
  454. self._api.LITE_get_output_name(self._network, index, byref(c_name))
  455. return c_name.value.decode("utf-8")
  456. def get_all_input_name(self):
  457. """
  458. get all the input tensor name in the network
  459. """
  460. nr_input = c_size_t()
  461. self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)
  462. if nr_input.value > 0:
  463. names = (c_char_p * nr_input.value)()
  464. self._api.LITE_get_all_input_name(self._network, None, names)
  465. ret_name = [names[i].decode("utf-8") for i in range(nr_input.value)]
  466. return ret_name
  467. def get_all_output_name(self):
  468. """
  469. get all the output tensor name in the network
  470. """
  471. nr_output = c_size_t()
  472. self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)
  473. if nr_output.value > 0:
  474. names = (c_char_p * nr_output.value)()
  475. self._api.LITE_get_all_output_name(self._network, None, names)
  476. ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
  477. return ret_name
  478. def extra_configure(self, extra_config):
  479. """
  480. Extra Configuration to the network.
  481. """
  482. self._api.LITE_extra_configure(self._network, extra_config)
  483. def share_weights_with(self, src_network):
  484. """
  485. share weights with the loaded network
  486. """
  487. assert isinstance(src_network, LiteNetwork)
  488. self._api.LITE_shared_weight_with_network(self._network, src_network._network)
  489. def share_runtime_memroy(self, src_network):
  490. """
  491. share runtime memory with the srouce network
  492. """
  493. assert isinstance(src_network, LiteNetwork)
  494. self._api.LITE_share_runtime_memroy(self._network, src_network._network)
  495. def async_with_callback(self, async_callback):
  496. callback = wrap_async_callback(async_callback)
  497. self._api.LITE_set_async_callback(self._network, callback)
  498. def set_start_callback(self, start_callback):
  499. """
  500. when the network start forward, the callback will be called,
  501. the start_callback with param mapping from LiteIO to the corresponding
  502. LiteTensor
  503. """
  504. callback = start_finish_callback(start_callback)
  505. self._api.LITE_set_start_callback(self._network, callback)
  506. def set_finish_callback(self, finish_callback):
  507. """
  508. when the network finish forward, the callback will be called,
  509. the finish_callback with param mapping from LiteIO to the corresponding
  510. LiteTensor
  511. """
  512. callback = start_finish_callback(finish_callback)
  513. self._api.LITE_set_finish_callback(self._network, callback)
  514. def enable_profile_performance(self, profile_file):
  515. c_file = profile_file.encode("utf-8")
  516. self._api.LITE_enable_profile_performance(self._network, c_file)
  517. def set_network_algo_workspace_limit(self, size_limit):
  518. self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)
  519. def set_network_algo_policy(
  520. self, policy, shared_batch_size=0, binary_equal_between_batch=False
  521. ):
  522. """
  523. shared_batch_size: the batch size used by fastrun,
  524. Non-zero value means that fastrun use this batch size
  525. regardless of the batch size of the model. Zero means
  526. fastrun use batch size of the model
  527. binary_equal_between_batch: if the content of each input batch is
  528. binary equal,whether the content of each output batch is
  529. promised to be equal
  530. """
  531. self._api.LITE_set_network_algo_policy(self._network, policy)
  532. self._api.LITE_set_network_algo_fastrun_config(
  533. self._network, shared_batch_size, binary_equal_between_batch
  534. )
  535. def io_txt_dump(self, txt_file):
  536. c_file = txt_file.encode("utf-8")
  537. self._api.LITE_enable_io_txt_dump(self._network, c_file)
  538. def io_bin_dump(self, bin_dir):
  539. c_dir = bin_dir.encode("utf-8")
  540. self._api.LITE_enable_io_bin_dump(self._network, c_dir)
  541. def get_static_memory_alloc_info(self, log_dir="logs/test"):
  542. c_log_dir = log_dir.encode("utf-8")
  543. self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)
  544. def enable_global_layout_transform(self):
  545. self._api.LITE_enable_global_layout_transform(self._network)
  546. def dump_layout_transform_model(self, model_file):
  547. c_file = model_file.encode("utf-8")
  548. self._api.LITE_dump_layout_transform_model(self._network, c_file)
  549. def get_model_io_info(model_path, config=None):
  550. """
  551. get the model IO information before create the NetWork, this IO
  552. information can be used to configuration the NetWork.
  553. """
  554. api = _NetworkAPI()._lib
  555. c_path = c_char_p(model_path.encode("utf-8"))
  556. ios = _LiteNetworkIO()
  557. if config is not None:
  558. api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
  559. else:
  560. config = LiteConfig()
  561. api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
  562. ret_ios = LiteNetworkIO()
  563. for i in range(ios.input_size):
  564. ret_ios.add_input(ios.inputs[i])
  565. for i in range(ios.output_size):
  566. ret_ios.add_output(ios.outputs[i])
  567. return ret_ios