
network.py

# -*- coding: utf-8 -*-

from ctypes import *

import numpy as np

from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
from .struct import *
from .tensor import *


class LiteOptions(Structure):
    """
    the inference options which can optimize the network forwarding
    performance

    Attributes:
        weight_preprocess: the option which optimizes the inference performance
            by processing the weights of the network ahead of time

        fuse_preprocess: fuse preprocess pattern, like astype + pad_channel +
            dimshuffle

        fake_next_exec: whether only to perform non-computing tasks (like
            memory allocation and queue initialization) for the next exec. This will be
            reset to false when the graph is executed.

        var_sanity_check_first_run: disable var sanity check on the first run.
            Var sanity check is enabled on the first-time execution by default, and can
            be used to find some potential memory access errors in the operators

        const_shape: used to reduce memory usage and improve performance, since some
            static inference data structures can be omitted and some operators can be
            computed before forwarding

        force_dynamic_alloc: force dynamic memory allocation for all vars

        force_output_dynamic_alloc: force dynamic memory allocation for output tensors
            which are used as the input of the CallbackCaller operator

        no_profiling_on_shape_change: do not re-profile to select the best
            implementation algorithm when the input shape changes (use the previous algorithm)

        jit_level: execute supported operators with JIT (supports MLIR and
            NVRTC). Can only be used on Nvidia GPUs and x86 CPUs; this value indicates the JIT level:
            level 1: JIT execution with basic elemwise operators
            level 2: JIT execution of elemwise and reduce operators

        comp_node_seq_record_level: flag to optimize the inference performance by
            recording the kernel tasks in the first run; thereafter all the inference
            needs to do is execute the recorded tasks.
            level = 0 means normal inference
            level = 1 means use record inference
            level = 2 means record inference and free the extra memory

        graph_opt_level: network optimization level:
            0: disable
            1: level-1: inplace arith transformations during graph construction
            2: level-2: level-1, plus global optimization before graph compiling
            3: also enable JIT

        async_exec_level: level of dispatch on separate threads for different comp_node.
            0: do not perform async dispatch
            1: dispatch async if there are more than one comp node with limited queue
            mask 0b10: async if there are multiple comp nodes with
            mask 0b100: always async

    Examples:

        .. code-block::

            from megenginelite import *
            options = LiteOptions()
            options.weight_preprocess = True
            options.comp_node_seq_record_level = 1
            options.fuse_preprocess = True
    """

    _fields_ = [
        ("weight_preprocess", c_int),
        ("fuse_preprocess", c_int),
        ("fake_next_exec", c_int),
        ("var_sanity_check_first_run", c_int),
        ("const_shape", c_int),
        ("force_dynamic_alloc", c_int),
        ("force_output_dynamic_alloc", c_int),
        ("force_output_use_user_specified_memory", c_int),
        ("no_profiling_on_shape_change", c_int),
        ("jit_level", c_int),
        ("comp_node_seq_record_level", c_int),
        ("graph_opt_level", c_int),
        ("async_exec_level", c_int),
        # layout transform options
        ("enable_nchw44", c_int),
        ("enable_nchw44_dot", c_int),
        ("enable_nchw88", c_int),
        ("enable_nhwcd4", c_int),
        ("enable_nchw4", c_int),
        ("enable_nchw32", c_int),
        ("enable_nchw64", c_int),
    ]

    def __init__(self):
        self.weight_preprocess = False
        self.fuse_preprocess = False
        self.fake_next_exec = False
        self.var_sanity_check_first_run = True
        self.const_shape = False
        self.force_dynamic_alloc = False
        self.force_output_dynamic_alloc = False
        self.force_output_use_user_specified_memory = False
        self.no_profiling_on_shape_change = False
        self.jit_level = 0
        self.comp_node_seq_record_level = 0
        self.graph_opt_level = 2
        self.async_exec_level = 1

    def __repr__(self):
        data = {
            "weight_preprocess": bool(self.weight_preprocess),
            "fuse_preprocess": bool(self.fuse_preprocess),
            "fake_next_exec": bool(self.fake_next_exec),
            "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
            "const_shape": bool(self.const_shape),
            "force_dynamic_alloc": bool(self.force_dynamic_alloc),
            "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
            "force_output_use_user_specified_memory": bool(
                self.force_output_use_user_specified_memory
            ),
            "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
            "jit_level": self.jit_level,
            "comp_node_seq_record_level": self.comp_node_seq_record_level,
            "graph_opt_level": self.graph_opt_level,
            "async_exec_level": self.async_exec_level,
        }
        return data.__repr__()


class LiteConfig(Structure):
    """
    Configuration when loading and compiling a network

    Attributes:
        has_compression: flag for whether the model is compressed; the compression
            method is stored in the model

        device_id: configure the device id of a network

        device_type: configure the device type of a network

        backend: configure the inference backend of a network, now only support
            megengine

        bare_model_cryption_name: the bare model encryption method name; a bare
            model is not packed with json information, so this encryption method name
            is needed to decrypt the encrypted bare model

        options: configuration of LiteOptions

    Examples:

        .. code-block::

            from megenginelite import *
            config = LiteConfig()
            config.has_compression = False
            config.device_type = LiteDeviceType.LITE_CPU
            config.backend = LiteBackend.LITE_DEFAULT
            config.bare_model_cryption_name = "AES_default".encode("utf-8")
    """

    _fields_ = [
        ("has_compression", c_int),
        ("device_id", c_int),
        ("device_type", c_int),
        ("backend", c_int),
        ("_bare_model_cryption_name", c_char_p),
        ("options", LiteOptions),
    ]

    def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
        self.device_type = device_type
        if option:
            self.options = option
        else:
            self.options = LiteOptions()

        self._bare_model_cryption_name = c_char_p(b"")
        self.use_loader_dynamic_param = 0
        self.has_compression = 0
        self.backend = LiteBackend.LITE_DEFAULT

    @property
    def bare_model_cryption_name(self):
        return self._bare_model_cryption_name.decode("utf-8")

    @bare_model_cryption_name.setter
    def bare_model_cryption_name(self, name):
        if isinstance(name, str):
            self._bare_model_cryption_name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._bare_model_cryption_name = name

    def __repr__(self):
        data = {
            "has_compression": bool(self.has_compression),
            "device_id": self.device_id,
            "device_type": LiteDeviceType(self.device_type),
            "backend": LiteBackend(self.backend),
            "bare_model_cryption_name": self.bare_model_cryption_name,
            "options": self.options,
        }
        return data.__repr__()


class LiteExtraConfig(Structure):
    """
    Extra configuration when loading and compiling the graph

    Attributes:
        disable_configure_by_model_info: disable the configuration dumped with the
            model; if set to true, none of the configuration in the model will be
            applied, and users should configure the network themselves.
    """

    _fields_ = [
        ("disable_configure_by_model_info", c_int),
    ]

    def __init__(self, disable_model_config=False):
        self.disable_configure_by_model_info = disable_model_config

    def __repr__(self):
        data = {
            "disable_configure_by_model_info": bool(
                self.disable_configure_by_model_info
            ),
        }
        return data.__repr__()


class LiteIO(Structure):
    """
    Configure the network input and output items; the input and output
    tensor information is described here

    Attributes:
        name: the tensor name in the graph corresponding to the IO

        is_host: used to mark where the input tensor comes from and where the output
            tensor will be copied to; if is_host is true, the input comes from the host
            and the output is copied to the host, otherwise to the device. Sometimes the
            input comes from the device and the output does not need to be copied to
            the host; default is true.

        io_type: the IO type, it can be SHAPE or VALUE; when SHAPE is set, the input or
            output tensor value is invalid, only the shape will be set; default is VALUE

        config_layout: the layout of the config from the user; if another layout is set
            before forward or gotten after forward, this layout will be bypassed. If no
            other layout is set before forward, this layout will work. If this layout is
            not set, the model will forward with its origin layout. If set on an output,
            it will be used to check.

    Note:
        if another layout is set to the input tensor before forwarding, this layout will not work

        if no layout is set before forwarding, the model will forward with its origin layout

        if a layout is set on an output tensor, it will be used to check whether the layout computed from the network is correct

    Examples:

        .. code-block::

            from megenginelite import *
            io = LiteIO(
                "data2",
                is_host=True,
                io_type=LiteIOType.LITE_IO_SHAPE,
                layout=LiteLayout([2, 4, 4]),
            )
    """

    _fields_ = [
        ("_name", c_char_p),
        ("is_host", c_int),
        ("io_type", c_int),
        ("config_layout", LiteLayout),
    ]

    def __init__(
        self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        if type(name) == str:
            self._name = c_char_p(name.encode("utf-8"))
        else:
            self._name = c_char_p(name)

        if layout:
            self.config_layout = layout
        else:
            self.config_layout = LiteLayout()

        self.is_host = is_host
        self.io_type = io_type

    @property
    def name(self):
        """
        get the name of the IO item
        """
        return self._name.decode("utf-8")

    @name.setter
    def name(self, name):
        """
        set the name of the IO item
        """
        if isinstance(name, str):
            self._name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._name = name

    def __repr__(self):
        data = {
            "name": self.name,
            "is_host": bool(self.is_host),
            "io_type": LiteIOType(self.io_type),
            "config_layout": self.config_layout,
        }
        return data.__repr__()

    def __hash__(self):
        return hash(self.name)


class _LiteNetworkIO(Structure):
    _fields_ = [
        ("inputs", POINTER(LiteIO)),
        ("outputs", POINTER(LiteIO)),
        ("input_size", c_size_t),
        ("output_size", c_size_t),
    ]

    def __init__(self):
        self.inputs = POINTER(LiteIO)()
        self.outputs = POINTER(LiteIO)()
        self.input_size = 0
        self.output_size = 0


class LiteNetworkIO(object):
    """
    The input and output information for the user when loading the network;
    the NetworkIO will remain in the network until the network is destroyed.

    Attributes:
        inputs: all the input tensor information that will be configured to the network

        outputs: all the output tensor information that will be configured to the network

    Examples:

        .. code-block::

            from megenginelite import *
            input_io = LiteIO("data", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
            io = LiteNetworkIO()
            io.add_input(input_io)
            output_io = LiteIO("out", is_host=True, layout=LiteLayout([1, 1000]))
            io.add_output(output_io)
    """

    def __init__(self, inputs=None, outputs=None):
        self.inputs = []
        self.outputs = []
        if inputs:
            for i in inputs:
                if isinstance(i, list):
                    self.inputs.append(LiteIO(*i))
                else:
                    assert isinstance(
                        i, LiteIO
                    ), "the parameters used to construct LiteNetworkIO must be a list of LiteIO members or a LiteIO."
                    self.inputs.append(i)
        if outputs:
            for i in outputs:
                if isinstance(i, list):
                    self.outputs.append(LiteIO(*i))
                else:
                    assert isinstance(
                        i, LiteIO
                    ), "the parameters used to construct LiteNetworkIO must be a list of LiteIO members or a LiteIO."
                    self.outputs.append(i)

    def add_input(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add input information into LiteNetworkIO
        """
        if isinstance(obj, LiteIO):
            self.inputs.append(obj)
        else:
            name = obj
            self.add_input(LiteIO(name, is_host, io_type, layout))

    def add_output(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add output information into LiteNetworkIO
        """
        if isinstance(obj, LiteIO):
            self.outputs.append(obj)
        else:
            name = obj
            self.add_output(LiteIO(name, is_host, io_type, layout))

    def _create_network_io(self):
        network_io = _LiteNetworkIO()
        length = 1 if len(self.inputs) == 0 else len(self.inputs)
        self.c_inputs = (LiteIO * length)(*self.inputs)
        length = 1 if len(self.outputs) == 0 else len(self.outputs)
        self.c_outputs = (LiteIO * length)(*self.outputs)
        network_io.inputs = pointer(self.c_inputs[0])
        network_io.outputs = pointer(self.c_outputs[0])
        network_io.input_size = len(self.inputs)
        network_io.output_size = len(self.outputs)
        return network_io

    def __repr__(self):
        data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
        return data.__repr__()
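

# A minimal sketch (not part of the original file) of the list-style
# constructor accepted by LiteNetworkIO.__init__: each inner list is expanded
# into LiteIO(*args), so it mirrors LiteIO's positional parameters
# (name, is_host, io_type, layout). "data" and "out" are hypothetical
# tensor names.
#
#     io = LiteNetworkIO(
#         inputs=[["data", True, LiteIOType.LITE_IO_VALUE, LiteLayout([1, 3, 224, 224])]],
#         outputs=[["out", True, LiteIOType.LITE_IO_VALUE, None]],
#     )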


LiteAsyncCallback = CFUNCTYPE(c_int)
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)


def wrap_async_callback(func):
    # bind the wrapper at module level so the ctypes callback object is not
    # garbage collected while the C side still holds a pointer to it
    global wrapper

    @CFUNCTYPE(c_int)
    def wrapper():
        return func()

    return wrapper


def start_finish_callback(func):
    # same module-level binding as above, to keep the callback alive
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        ios = {}
        for i in range(size):
            tensor = LiteTensor()
            tensor._tensor = c_void_p(c_tensors[i])
            tensor.update()
            io = c_ios[i]
            ios[io] = tensor
        return func(ios)

    return wrapper
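

# A minimal sketch (not part of the original file) of a finish callback: the
# wrapped function receives a dict mapping each LiteIO to its LiteTensor, and
# must return an int status code (0 for success). The printing logic is
# illustrative only, and `network` is assumed to be a loaded LiteNetwork.
#
#     def on_finish(ios):
#         for io, tensor in ios.items():
#             print(io.name, tensor.layout)
#         return 0
#
#     network.set_finish_callback(on_finish)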


class _NetworkAPI(_LiteCObjBase):
    """
    get the network api from the lib
    """

    _api_ = [
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
        (
            "LITE_get_model_io_info_by_path",
            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        (
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
    ]


class LiteNetwork(object):
    """
    the network to load a model and forward

    Examples:

        .. code-block::

            from megenginelite import *
            config = LiteConfig()
            config.device_type = LiteDeviceType.LITE_CPU
            network = LiteNetwork(config)
            network.load("model_path")

            input_name = network.get_input_name(0)
            input_tensor = network.get_io_tensor(input_name)
            output_name = network.get_output_name(0)
            output_tensor = network.get_io_tensor(output_name)

            input_tensor.set_data_by_copy(input_data)
            network.forward()
            network.wait()
    """

    _api = _NetworkAPI()._lib

    def __init__(self, config=None, io=None):
        """
        create a network with config and networkio
        """
        self._network = _Cnetwork()

        if config:
            self.config = config
        else:
            self.config = LiteConfig()

        if io:
            self.network_io = io
        else:
            self.network_io = LiteNetworkIO()

        c_network_io = self.network_io._create_network_io()
        self._api.LITE_make_network(byref(self._network), self.config, c_network_io)

    def __repr__(self):
        data = {"config": self.config, "IOs": self.network_io}
        return data.__repr__()

    def __del__(self):
        self._api.LITE_destroy_network(self._network)

    def load(self, path):
        """
        load network from the given path
        """
        c_path = c_char_p(path.encode("utf-8"))
        self._api.LITE_load_model_from_path(self._network, c_path)

    def forward(self):
        """
        forward the network with filled input data and fill the output data
        to the output tensor
        """
        self._api.LITE_forward(self._network)

    def wait(self):
        """
        wait until forward finishes in sync mode
        """
        self._api.LITE_wait(self._network)

    def is_cpu_inplace_mode(self):
        """
        whether the network runs in cpu inplace mode

        Returns:
            True if using inplace mode, else False
        """
        inplace = c_int()
        self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
        return bool(inplace.value)

    def enable_cpu_inplace_mode(self):
        """
        set cpu forward in inplace mode, in which cpu forward only creates one
        thread

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_cpu_inplace_mode(self._network)

    def use_tensorrt(self):
        """
        use TensorRT

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_use_tensorrt(self._network)

    @property
    def device_id(self):
        """
        get the device id

        Returns:
            the device id used by the current network
        """
        device_id = c_int()
        self._api.LITE_get_device_id(self._network, byref(device_id))
        return device_id.value

    @device_id.setter
    def device_id(self, device_id):
        """
        set the device id

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_device_id(self._network, device_id)

    @property
    def stream_id(self):
        """
        get the stream id

        Returns:
            the value of the stream id set for the network
        """
        stream_id = c_int()
        self._api.LITE_get_stream_id(self._network, byref(stream_id))
        return stream_id.value

    @stream_id.setter
    def stream_id(self, stream_id):
        """
        set the stream id

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_stream_id(self._network, stream_id)

    @property
    def threads_number(self):
        """
        get the thread number of the network

        Returns:
            the number of threads set in the network
        """
        nr_thread = c_size_t()
        self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
        return nr_thread.value

    @threads_number.setter
    def threads_number(self, nr_threads):
        """
        set the network to forward in multithread mode, and set the thread number

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_cpu_threads_number(self._network, nr_threads)
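
    # A minimal sketch (not part of the original file) of the required
    # ordering: per the Notes above, device/stream/thread settings only take
    # effect if applied before load(). "model_path" is a hypothetical path.
    #
    #     network = LiteNetwork()
    #     network.threads_number = 4
    #     network.device_id = 0
    #     network.load("model_path")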

    def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
        """
        get the input or output tensor by its name

        Args:
            name: the name of the io tensor
            phase: the type of LiteTensor, this is useful to separate input and output tensors with the same name

        Returns:
            the tensor with the given name and type
        """
        if type(name) == str:
            c_name = c_char_p(name.encode("utf-8"))
        else:
            c_name = c_char_p(name)
        tensor = LiteTensor()
        self._api.LITE_get_io_tensor(
            self._network, c_name, phase, byref(tensor._tensor)
        )
        tensor.update()
        return tensor

    def get_input_name(self, index):
        """
        get the input name by index in the network

        Args:
            index: the index of the input name

        Returns:
            the name of the input tensor with the given index
        """
        c_name = c_char_p()
        self._api.LITE_get_input_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_output_name(self, index):
        """
        get the output name by index in the network

        Args:
            index: the index of the output name

        Returns:
            the name of the output tensor with the given index
        """
        c_name = c_char_p()
        self._api.LITE_get_output_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_all_input_name(self):
        """
        get all the input tensor names in the network

        Returns:
            the names of all input tensors in the network
        """
        nr_input = c_size_t()
        self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)

        if nr_input.value > 0:
            names = (c_char_p * nr_input.value)()
            self._api.LITE_get_all_input_name(self._network, None, names)
            ret_name = [names[i].decode("utf-8") for i in range(nr_input.value)]
            return ret_name

    def get_all_output_name(self):
        """
        get all the output tensor names in the network

        Returns:
            the names of all output tensors in the network
        """
        nr_output = c_size_t()
        self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)

        if nr_output.value > 0:
            names = (c_char_p * nr_output.value)()
            self._api.LITE_get_all_output_name(self._network, None, names)
            ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
            return ret_name

    def extra_configure(self, extra_config):
        """
        apply extra configuration to the network

        Args:
            extra_config: the LiteExtraConfig to apply
        """
        self._api.LITE_extra_configure(self._network, extra_config)
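
    # A minimal sketch (not part of the original file) showing how
    # LiteExtraConfig connects to extra_configure: here the configuration
    # dumped inside the model is ignored, so the user configures the network
    # manually instead.
    #
    #     extra = LiteExtraConfig(disable_model_config=True)
    #     network.extra_configure(extra)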

    def share_weights_with(self, src_network):
        """
        share weights with the loaded network

        Args:
            src_network: the network to share weights with
        """
        assert isinstance(src_network, LiteNetwork)
        self._api.LITE_shared_weight_with_network(self._network, src_network._network)

    def share_runtime_memroy(self, src_network):
        """
        share runtime memory with the source network

        Args:
            src_network: the network to share runtime memory with
        """
        assert isinstance(src_network, LiteNetwork)
        self._api.LITE_share_runtime_memroy(self._network, src_network._network)

    def async_with_callback(self, async_callback):
        """
        set the network to forward in async mode and set the AsyncCallback
        callback function

        Args:
            async_callback: the callback to set for the network
        """
        callback = wrap_async_callback(async_callback)
        self._api.LITE_set_async_callback(self._network, callback)
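
    # A minimal sketch (not part of the original file) of asynchronous
    # execution: the callback takes no arguments and must return an int
    # status code (0 for success). The threading.Event is illustrative.
    #
    #     import threading
    #     done = threading.Event()
    #
    #     def on_forward_done():
    #         done.set()
    #         return 0
    #
    #     network.async_with_callback(on_forward_done)
    #     network.forward()
    #     done.wait()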

    def set_start_callback(self, start_callback):
        """
        when the network starts to forward, the callback will be called;
        the start_callback takes a param mapping from LiteIO to the
        corresponding LiteTensor

        Args:
            start_callback: the callback to set for the network
        """
        callback = start_finish_callback(start_callback)
        self._api.LITE_set_start_callback(self._network, callback)

    def set_finish_callback(self, finish_callback):
        """
        when the network finishes forwarding, the callback will be called;
        the finish_callback takes a param mapping from LiteIO to the
        corresponding LiteTensor

        Args:
            finish_callback: the callback to set for the network
        """
        callback = start_finish_callback(finish_callback)
        self._api.LITE_set_finish_callback(self._network, callback)

    def enable_profile_performance(self, profile_file):
        """
        enable profiling of the network performance and save the profiled
        information into the given file

        Args:
            profile_file: the file to save the profile information
        """
        c_file = profile_file.encode("utf-8")
        self._api.LITE_enable_profile_performance(self._network, c_file)

    def set_network_algo_workspace_limit(self, size_limit):
        """
        set the operator workspace limitation in the target network; some
        operators may use a large workspace to get good performance, and
        setting a workspace limitation can save memory but may influence
        the performance

        Args:
            size_limit: the byte size of the workspace limitation
        """
        self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)

    def set_network_algo_policy(
        self, policy, shared_batch_size=0, binary_equal_between_batch=False
    ):
        """
        set the network algorithm search policy for fast-run

        Args:
            policy: the algorithm search policy
            shared_batch_size: the batch size used by fastrun;
                a non-zero value means that fastrun uses this batch size
                regardless of the batch size of the model, while zero means
                fastrun uses the batch size of the model
            binary_equal_between_batch: if the content of each input batch is
                binary equal, whether the content of each output batch is
                promised to be equal
        """
        self._api.LITE_set_network_algo_policy(self._network, policy)
        self._api.LITE_set_network_algo_fastrun_config(
            self._network, shared_batch_size, binary_equal_between_batch
        )
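
    # A minimal sketch (not part of the original file): assuming the
    # LiteAlgoSelectStrategy enum exported from .struct, this enables
    # profile-based algorithm selection with a fixed fastrun batch size of 1.
    #
    #     network.set_network_algo_policy(
    #         LiteAlgoSelectStrategy.LITE_ALGO_PROFILE, shared_batch_size=1
    #     )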

    def io_txt_dump(self, txt_file):
        """
        dump all input/output tensors of all operators to the output file, in
        txt format; users can use this function to debug compute errors

        Args:
            txt_file: the txt file
        """
        c_file = txt_file.encode("utf-8")
        self._api.LITE_enable_io_txt_dump(self._network, c_file)

    def io_bin_dump(self, bin_dir):
        """
        dump all input/output tensors of all operators to the output file, in
        binary format; users can use this function to debug compute errors

        Args:
            bin_dir: the binary file directory
        """
        c_dir = bin_dir.encode("utf-8")
        self._api.LITE_enable_io_bin_dump(self._network, c_dir)

    def get_static_memory_alloc_info(self, log_dir="logs/test"):
        """
        get the static peak memory info shown by graph visualization

        Args:
            log_dir: the directory to save the information log
        """
        c_log_dir = log_dir.encode("utf-8")
        self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)

    def enable_global_layout_transform(self):
        """
        set global layout transform optimization for the network; global
        layout optimization can automatically determine the layout of every
        operator in the network by profiling, and thus it can improve the
        performance of network forwarding
        """
        self._api.LITE_enable_global_layout_transform(self._network)

    def dump_layout_transform_model(self, model_file):
        """
        dump the network after global layout transform optimization to the
        specified path

        Args:
            model_file: the file path to dump the model
        """
        c_file = model_file.encode("utf-8")
        self._api.LITE_dump_layout_transform_model(self._network, c_file)


def get_model_io_info(model_path, config=None):
    """
    get the model io information before the model is loaded, by model path.

    Args:
        model_path: the model path from which to get the model IO information
        config: the model configuration

    Returns:
        the input and output information in the network configuration
    """
    api = _NetworkAPI()._lib
    c_path = c_char_p(model_path.encode("utf-8"))

    ios = _LiteNetworkIO()

    if config is not None:
        api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
    else:
        config = LiteConfig()
        api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))

    ret_ios = LiteNetworkIO()
    for i in range(ios.input_size):
        ret_ios.add_input(ios.inputs[i])
    for i in range(ios.output_size):
        ret_ios.add_output(ios.outputs[i])
    return ret_ios
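

# A minimal sketch (not part of the original file) of inspecting a model's IO
# before loading it; "model_path" is a hypothetical path.
#
#     ios = get_model_io_info("model_path")
#     print(ios)  # LiteNetworkIO.__repr__ lists the inputs and outputs
#     network = LiteNetwork(io=ios)
#     network.load("model_path")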