network.py

# -*- coding: utf-8 -*-
from ctypes import *

import numpy as np

from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
from .struct import *
from .tensor import *


class LiteOptions(Structure):
    """
    the inference options which can optimize the network forwarding
    performance

    Attributes:
        weight_preprocess: the option to optimize the inference performance by
            processing the weights of the network ahead of time

        fuse_preprocess: fuse preprocess pattern, like astype + pad_channel +
            dimshuffle

        fake_next_exec: whether to only perform non-computing tasks (like
            memory allocation and queue initialization) for the next exec. This
            will be reset to false when the graph is executed.

        var_sanity_check_first_run: disable var sanity check on the first run.
            Var sanity check is enabled on the first-time execution by default,
            and can be used to find some potential memory access errors in the
            operator

        const_shape: used to reduce memory usage and improve performance, since
            some static inference data structures can be omitted and some
            operators can be computed before forwarding

        force_dynamic_alloc: force dynamic memory allocation for all vars

        force_output_dynamic_alloc: force dynamic memory allocation for output
            tensors which are used as the input of the CallbackCaller operator

        force_output_use_user_specified_memory: force the output tensors to use
            the memory specified by the user

        no_profiling_on_shape_change: do not re-profile to select the best
            implement algo when the input shape changes (use the previous algo)

        jit_level: execute supported operators with JIT (supports MLIR and
            NVRTC). Can only be used on Nvidia GPUs and x86 CPUs; this value
            indicates the JIT level:

            level 1: JIT execute with basic elemwise operators

            level 2: JIT execute elemwise and reduce operators

        comp_node_seq_record_level: flag to optimize the inference performance
            by recording the kernel tasks in the first run; afterwards all the
            inference needs to do is execute the recorded tasks.

            level = 0 means normal inference

            level = 1 means use record inference

            level = 2 means record inference and free the extra memory

        graph_opt_level: network optimization level:

            0: disable

            1: level-1: inplace arith transformations during graph construction

            2: level-2: level-1, plus global optimization before graph compiling

            3: also enable JIT

        async_exec_level: level of dispatch on separate threads for different
            comp_node.

            0: do not perform async dispatch

            1: dispatch async if there are more than one comp node with limited
            queue

            mask 0b10: async if there are multiple comp nodes

            mask 0b100: always async

    Examples:

        .. code-block::

            from megenginelite import *

            options = LiteOptions()
            options.weight_preprocess = True
            options.comp_node_seq_record_level = 1
            options.fuse_preprocess = True
    """
    _fields_ = [
        ("weight_preprocess", c_int),
        ("fuse_preprocess", c_int),
        ("fake_next_exec", c_int),
        ("var_sanity_check_first_run", c_int),
        ("const_shape", c_int),
        ("force_dynamic_alloc", c_int),
        ("force_output_dynamic_alloc", c_int),
        ("force_output_use_user_specified_memory", c_int),
        ("no_profiling_on_shape_change", c_int),
        ("jit_level", c_int),
        ("comp_node_seq_record_level", c_int),
        ("graph_opt_level", c_int),
        ("async_exec_level", c_int),
        # layout transform options
        ("enable_nchw44", c_int),
        ("enable_nchw44_dot", c_int),
        ("enable_nchw88", c_int),
        ("enable_nhwcd4", c_int),
        ("enable_nchw4", c_int),
        ("enable_nchw32", c_int),
        ("enable_nchw64", c_int),
    ]

    def __init__(self):
        self.weight_preprocess = False
        self.fuse_preprocess = False
        self.fake_next_exec = False
        self.var_sanity_check_first_run = True
        self.const_shape = False
        self.force_dynamic_alloc = False
        self.force_output_dynamic_alloc = False
        self.force_output_use_user_specified_memory = False
        self.no_profiling_on_shape_change = False
        self.jit_level = 0
        self.comp_node_seq_record_level = 0
        self.graph_opt_level = 2
        self.async_exec_level = 1

    def __repr__(self):
        data = {
            "weight_preprocess": bool(self.weight_preprocess),
            "fuse_preprocess": bool(self.fuse_preprocess),
            "fake_next_exec": bool(self.fake_next_exec),
            "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
            "const_shape": bool(self.const_shape),
            "force_dynamic_alloc": bool(self.force_dynamic_alloc),
            "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
            "force_output_use_user_specified_memory": bool(
                self.force_output_use_user_specified_memory
            ),
            "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
            "jit_level": self.jit_level,
            "comp_node_seq_record_level": self.comp_node_seq_record_level,
            "graph_opt_level": self.graph_opt_level,
            "async_exec_level": self.async_exec_level,
        }
        return data.__repr__()


class LiteConfig(Structure):
    """
    Configuration when loading and compiling a network

    Attributes:
        has_compression: flag whether the model is compressed; the compression
            method is stored in the model

        device_id: configure the device id of a network

        device_type: configure the device type of a network

        backend: configure the inference backend of a network; currently only
            megengine is supported

        bare_model_cryption_name: the bare model encryption method name. A bare
            model is not packed with json information, and this encryption
            method name is needed to decrypt an encrypted bare model

        options: configuration of LiteOptions

        auto_optimize_inference: lite will detect the device information and
            set the options heuristically

    Examples:

        .. code-block::

            from megenginelite import *

            config = LiteConfig()
            config.has_compression = False
            config.device_type = LiteDeviceType.LITE_CPU
            config.backend = LiteBackend.LITE_DEFAULT
            config.bare_model_cryption_name = "AES_default".encode("utf-8")
            config.auto_optimize_inference = False
    """
    _fields_ = [
        ("has_compression", c_int),
        ("device_id", c_int),
        ("device_type", c_int),
        ("backend", c_int),
        ("_bare_model_cryption_name", c_char_p),
        ("options", LiteOptions),
        ("auto_optimize_inference", c_int),
    ]

    def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
        self.device_type = device_type
        if option:
            self.options = option
        else:
            self.options = LiteOptions()

        self._bare_model_cryption_name = c_char_p(b"")
        self.use_loader_dynamic_param = 0
        self.has_compression = 0
        self.backend = LiteBackend.LITE_DEFAULT
        self.auto_optimize_inference = 0

    @property
    def bare_model_cryption_name(self):
        return self._bare_model_cryption_name.decode("utf-8")

    @bare_model_cryption_name.setter
    def bare_model_cryption_name(self, name):
        if isinstance(name, str):
            self._bare_model_cryption_name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._bare_model_cryption_name = name

    def __repr__(self):
        data = {
            "has_compression": bool(self.has_compression),
            "device_id": self.device_id,
            "device_type": LiteDeviceType(self.device_type),
            "backend": LiteBackend(self.backend),
            "bare_model_cryption_name": self.bare_model_cryption_name,
            "options": self.options,
            "auto_optimize_inference": self.auto_optimize_inference,
        }
        return data.__repr__()


class LiteExtraConfig(Structure):
    """
    Extra configuration when loading and compiling the graph

    Attributes:
        disable_configure_by_model_info: disable the configuration dumped with
            the model. If set to true, none of the configuration in the model
            will apply, and users should configure the network themselves.
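
    Examples:

        A minimal usage sketch; ``network`` is assumed to be an already
        constructed LiteNetwork.

        .. code-block::

            from megenginelite import *

            extra_config = LiteExtraConfig(disable_model_config=True)
            network.extra_configure(extra_config)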
  186. """
  187. _fields_ = [
  188. ("disable_configure_by_model_info", c_int),
  189. ]
  190. def __init__(self, disable_model_config=False):
  191. self.disable_configure_by_model_info = disable_model_config
  192. def __repr__(self):
  193. data = {
  194. "disable_configure_by_model_info": bool(
  195. self.disable_configure_by_model_info
  196. ),
  197. }
  198. return data.__repr__()


class LiteIO(Structure):
    """
    config the network input and output items; the input and output tensor
    information is described here

    Attributes:
        name: the tensor name in the graph corresponding to the IO

        is_host: used to mark where the input tensor comes from and where the
            output tensor will be copied to. If is_host is true, the input comes
            from the host and the output is copied to the host; otherwise they
            stay on the device. Sometimes the input is from the device and the
            output does not need to be copied to the host; default is true.

        io_type: the IO type, it can be SHAPE or VALUE. When SHAPE is set, the
            input or output tensor value is invalid and only the shape will be
            set; default is VALUE

        config_layout: the layout configured by the user. If another layout is
            set before forward or fetched after forward, this layout will be
            bypassed. If no other layout is set before forward, this layout
            will work. If this layout is not set, the model will forward with
            its origin layout. If set on an output, it will be used for
            checking.

    Note:
        if another layout is set on an input tensor before forwarding, this layout will not work

        if no layout is set before forwarding, the model will forward with its origin layout

        if a layout is set on an output tensor, it will be used to check whether the layout computed from the network is correct

    Examples:

        .. code-block::

            from megenginelite import *

            io = LiteIO(
                "data2",
                is_host=True,
                io_type=LiteIOType.LITE_IO_SHAPE,
                layout=LiteLayout([2, 4, 4]),
            )
    """
    _fields_ = [
        ("_name", c_char_p),
        ("is_host", c_int),
        ("io_type", c_int),
        ("config_layout", LiteLayout),
    ]

    def __init__(
        self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        if type(name) == str:
            self._name = c_char_p(name.encode("utf-8"))
        else:
            self._name = c_char_p(name)

        if layout:
            self.config_layout = layout
        else:
            self.config_layout = LiteLayout()

        self.is_host = is_host
        self.io_type = io_type

    @property
    def name(self):
        """
        get the name of the IO item
        """
        return self._name.decode("utf-8")

    @name.setter
    def name(self, name):
        """
        set the name of the IO item
        """
        if isinstance(name, str):
            self._name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._name = name

    def __repr__(self):
        data = {
            "name": self.name,
            "is_host": bool(self.is_host),
            "io_type": LiteIOType(self.io_type),
            "config_layout": self.config_layout,
        }
        return data.__repr__()

    def __hash__(self):
        return hash(self.name)


class _LiteNetworkIO(Structure):
    _fields_ = [
        ("inputs", POINTER(LiteIO)),
        ("outputs", POINTER(LiteIO)),
        ("input_size", c_size_t),
        ("output_size", c_size_t),
    ]

    def __init__(self):
        self.inputs = POINTER(LiteIO)()
        self.outputs = POINTER(LiteIO)()
        self.input_size = 0
        self.output_size = 0


class LiteNetworkIO(object):
    """
    the input and output information used when the user loads the network;
    the NetworkIO will remain in the network until the network is destroyed.

    Attributes:
        inputs: all the input tensor information that will be configured to the network

        outputs: all the output tensor information that will be configured to the network

    Examples:

        .. code-block::

            from megenginelite import *

            input_io = LiteIO("data", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
            io = LiteNetworkIO()
            io.add_input(input_io)
            output_io = LiteIO("out", is_host=True, layout=LiteLayout([1, 1000]))
            io.add_output(output_io)
    """
    def __init__(self, inputs=None, outputs=None):
        self.inputs = []
        self.outputs = []
        if inputs:
            for i in inputs:
                if isinstance(i, list):
                    self.inputs.append(LiteIO(*i))
                else:
                    assert isinstance(
                        i, LiteIO
                    ), "the params to construct LiteNetworkIO must be LiteIO members or lists of LiteIO constructor arguments."
                    self.inputs.append(i)
        if outputs:
            for i in outputs:
                if isinstance(i, list):
                    self.outputs.append(LiteIO(*i))
                else:
                    assert isinstance(
                        i, LiteIO
                    ), "the params to construct LiteNetworkIO must be LiteIO members or lists of LiteIO constructor arguments."
                    self.outputs.append(i)

    def add_input(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add input information into LiteNetworkIO
        """
        if isinstance(obj, LiteIO):
            self.inputs.append(obj)
        else:
            name = obj
            self.add_input(LiteIO(name, is_host, io_type, layout))

    def add_output(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add output information into LiteNetworkIO
        """
        if isinstance(obj, LiteIO):
            self.outputs.append(obj)
        else:
            name = obj
            self.add_output(LiteIO(name, is_host, io_type, layout))

    def _create_network_io(self):
        network_io = _LiteNetworkIO()
        # allocate at least one element so pointer(self.c_inputs[0]) is valid
        # even with no IOs; keep references on self so the ctypes arrays
        # outlive the returned _LiteNetworkIO
        length = 1 if len(self.inputs) == 0 else len(self.inputs)
        self.c_inputs = (LiteIO * length)(*self.inputs)
        length = 1 if len(self.outputs) == 0 else len(self.outputs)
        self.c_outputs = (LiteIO * length)(*self.outputs)
        network_io.inputs = pointer(self.c_inputs[0])
        network_io.outputs = pointer(self.c_outputs[0])
        network_io.input_size = len(self.inputs)
        network_io.output_size = len(self.outputs)
        return network_io

    def __repr__(self):
        data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
        return data.__repr__()


LiteAsyncCallback = CFUNCTYPE(c_int)
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)


def wrap_async_callback(func):
    # bind the wrapper to a module-level name so the ctypes callback object is
    # not garbage collected while the C side still holds it
    global wrapper

    @CFUNCTYPE(c_int)
    def wrapper():
        return func()

    return wrapper


def start_finish_callback(func):
    # same trick as wrap_async_callback: keep the ctypes callback alive
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        # rebuild a {LiteIO: LiteTensor} mapping from the raw C arrays
        ios = {}
        for i in range(size):
            tensor = LiteTensor(physic_construct=False)
            tensor._tensor = c_void_p(c_tensors[i])
            tensor.update()

            io = c_ios[i]
            ios[io] = tensor
        return func(ios)

    return wrapper
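

# A minimal sketch (not part of the API) of how a finish callback could look;
# "on_finish" and the printed fields are illustrative only. The plain Python
# function is passed to LiteNetwork.set_finish_callback, which wraps it with
# start_finish_callback:
#
#     def on_finish(ios):  # ios maps LiteIO -> LiteTensor
#         for io, tensor in ios.items():
#             print(io.name, tensor.layout)
#         return 0  # report success with 0
#
#     network.set_finish_callback(on_finish)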


class _NetworkAPI(_LiteCObjBase):
    """
    get the network api from the lib
    """

    _api_ = [
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
        (
            "LITE_get_model_io_info_by_path",
            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        (
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
    ]


class LiteNetwork(object):
    """
    the network to load a model and forward

    Examples:

        .. code-block::

            from megenginelite import *

            config = LiteConfig()
            config.device_type = LiteDeviceType.LITE_CPU
            network = LiteNetwork(config)
            network.load("model_path")

            input_name = network.get_input_name(0)
            input_tensor = network.get_io_tensor(input_name)
            output_name = network.get_output_name(0)
            output_tensor = network.get_io_tensor(output_name)

            input_tensor.set_data_by_copy(input_data)
            network.forward()
            network.wait()
    """

    _api = _NetworkAPI()._lib

    def __init__(self, config=None, io=None):
        """
        create a network with the given config and network IO
        """
        self._network = _Cnetwork()

        if config:
            self.config = config
        else:
            self.config = LiteConfig()

        if io:
            self.network_io = io
        else:
            self.network_io = LiteNetworkIO()

        c_network_io = self.network_io._create_network_io()
        self._api.LITE_make_network(byref(self._network), self.config, c_network_io)

    def __repr__(self):
        data = {"config": self.config, "IOs": self.network_io}
        return data.__repr__()

    def __del__(self):
        self._api.LITE_destroy_network(self._network)
    def load(self, path):
        """
        load the network from the given path
        """
        c_path = c_char_p(path.encode("utf-8"))
        self._api.LITE_load_model_from_path(self._network, c_path)

    def forward(self):
        """
        forward the network with the filled input data, and fill the output
        data to the output tensors
        """
        self._api.LITE_forward(self._network)

    def wait(self):
        """
        wait until the forward finishes in sync mode
        """
        self._api.LITE_wait(self._network)

    def is_cpu_inplace_mode(self):
        """
        whether the network runs in cpu inplace mode

        Returns:
            True if inplace mode is used, else False
        """
        inplace = c_int()
        self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
        return bool(inplace.value)
    def enable_cpu_inplace_mode(self):
        """
        set cpu forward in inplace mode, in which cpu forward creates only one
        thread

        Note:
            this must be set before the network is loaded
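
        Examples:

            A minimal sketch; ``"model_path"`` is a placeholder.

            .. code-block::

                network = LiteNetwork()
                network.enable_cpu_inplace_mode()
                network.load("model_path")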
  503. """
  504. self._api.LITE_set_cpu_inplace_mode(self._network)
  505. def use_tensorrt(self):
  506. """
  507. use TensorRT
  508. Note:
  509. this must be set before the network loaded
  510. """
  511. self._api.LITE_use_tensorrt(self._network)
    @property
    def device_id(self):
        """
        get the device id

        Returns:
            the device id used by the current network
        """
        device_id = c_int()
        self._api.LITE_get_device_id(self._network, byref(device_id))
        return device_id.value

    @device_id.setter
    def device_id(self, device_id):
        """
        set the device id

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_device_id(self._network, device_id)

    @property
    def stream_id(self):
        """
        get the stream id

        Returns:
            the value of the stream id set for the network
        """
        stream_id = c_int()
        self._api.LITE_get_stream_id(self._network, byref(stream_id))
        return stream_id.value

    @stream_id.setter
    def stream_id(self, stream_id):
        """
        set the stream id

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_stream_id(self._network, stream_id)
    @property
    def threads_number(self):
        """
        get the thread number of the network

        Returns:
            the number of threads set in the network
        """
        nr_thread = c_size_t()
        self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
        return nr_thread.value
    @threads_number.setter
    def threads_number(self, nr_threads):
        """
        set the network to forward in multithread mode, with the given thread
        number

        Note:
            this must be set before the network is loaded
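
        Examples:

            A minimal sketch; ``"model_path"`` is a placeholder.

            .. code-block::

                network = LiteNetwork()
                network.threads_number = 4
                network.load("model_path")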
  564. """
  565. self._api.LITE_set_cpu_threads_number(self._network, nr_threads)
    def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
        """
        get the input or output tensor by its name

        Args:
            name: the name of the io tensor

            phase: the type of LiteTensor; this is useful to separate input and
                output tensors with the same name

        Returns:
            the tensor with the given name and type
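
        Examples:

            A minimal sketch; the network is assumed to be loaded and to have
            an input tensor named ``"data"``.

            .. code-block::

                tensor = network.get_io_tensor("data")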
  574. """
  575. if type(name) == str:
  576. c_name = c_char_p(name.encode("utf-8"))
  577. else:
  578. c_name = c_char_p(name)
  579. tensor = LiteTensor(physic_construct=False)
  580. self._api.LITE_get_io_tensor(
  581. self._network, c_name, phase, byref(tensor._tensor)
  582. )
  583. tensor.update()
  584. return tensor
    def get_input_name(self, index):
        """
        get the input name by its index in the network

        Args:
            index: the index of the input name

        Returns:
            the name of the input tensor with the given index
        """
        c_name = c_char_p()
        self._api.LITE_get_input_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_output_name(self, index):
        """
        get the output name by its index in the network

        Args:
            index: the index of the output name

        Returns:
            the name of the output tensor with the given index
        """
        c_name = c_char_p()
        self._api.LITE_get_output_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_all_input_name(self):
        """
        get all the input tensor names in the network

        Returns:
            the names of all input tensors in the network
        """
        nr_input = c_size_t()
        self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)

        if nr_input.value > 0:
            names = (c_char_p * nr_input.value)()
            self._api.LITE_get_all_input_name(self._network, None, names)
            ret_name = [names[i].decode("utf-8") for i in range(nr_input.value)]
            return ret_name

    def get_all_output_name(self):
        """
        get all the output tensor names in the network

        Returns:
            the names of all output tensors in the network
        """
        nr_output = c_size_t()
        self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)
        if nr_output.value > 0:
            names = (c_char_p * nr_output.value)()
            self._api.LITE_get_all_output_name(self._network, None, names)
            ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
            return ret_name
    def extra_configure(self, extra_config):
        """
        apply extra configuration to the network

        Args:
            extra_config: the LiteExtraConfig to apply
        """
        self._api.LITE_extra_configure(self._network, extra_config)

    def share_weights_with(self, src_network):
        """
        share weights with the loaded network

        Args:
            src_network: the network to share weights from
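
        Examples:

            A minimal sketch; ``"model_path"`` is a placeholder and ``net1``
            is loaded before its weights are shared.

            .. code-block::

                net1 = LiteNetwork()
                net1.load("model_path")

                net2 = LiteNetwork()
                net2.share_weights_with(net1)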
  643. """
  644. assert isinstance(src_network, LiteNetwork)
  645. self._api.LITE_shared_weight_with_network(self._network, src_network._network)
  646. def share_runtime_memroy(self, src_network):
  647. """
  648. share runtime memory with the srouce network
  649. Args:
  650. src_network: the network to share runtime memory
  651. """
  652. assert isinstance(src_network, LiteNetwork)
  653. self._api.LITE_share_runtime_memroy(self._network, src_network._network)
    def async_with_callback(self, async_callback):
        """
        set the network to forward in async mode, and set the async callback
        function

        Args:
            async_callback: the callback to set for the network
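
        Examples:

            A minimal sketch: the callback takes no arguments and should
            return 0 on success.

            .. code-block::

                def callback():
                    print("forward finished")
                    return 0

                network.async_with_callback(callback)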
  660. """
  661. callback = wrap_async_callback(async_callback)
  662. self._api.LITE_set_async_callback(self._network, callback)
    def set_start_callback(self, start_callback):
        """
        when the network starts to forward, the callback will be called; the
        start_callback is called with a dict mapping each LiteIO to its
        corresponding LiteTensor

        Args:
            start_callback: the callback to set for the network
        """
        callback = start_finish_callback(start_callback)
        self._api.LITE_set_start_callback(self._network, callback)

    def set_finish_callback(self, finish_callback):
        """
        when the network finishes forwarding, the callback will be called; the
        finish_callback is called with a dict mapping each LiteIO to its
        corresponding LiteTensor

        Args:
            finish_callback: the callback to set for the network
        """
        callback = start_finish_callback(finish_callback)
        self._api.LITE_set_finish_callback(self._network, callback)
    def enable_profile_performance(self, profile_file):
        """
        enable profiling the network performance, and save the profiled
        information into the given file

        Args:
            profile_file: the file to save the profile information to
        """
        c_file = profile_file.encode("utf-8")
        self._api.LITE_enable_profile_performance(self._network, c_file)

    def set_network_algo_workspace_limit(self, size_limit):
        """
        set the operator workspace limitation in the target network; some
        operators may use a large workspace to get good performance, and
        setting a workspace limitation can save memory but may hurt the
        performance

        Args:
            size_limit: the workspace limitation in bytes
        """
        self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)
    def set_network_algo_policy(
        self, policy, shared_batch_size=0, binary_equal_between_batch=False
    ):
        """
        set the network algorithm search policy for fast-run

        Args:
            policy: the algorithm selection strategy

            shared_batch_size: the batch size used by fastrun. A non-zero
                value means fastrun uses this batch size regardless of the
                batch size of the model; zero means fastrun uses the batch
                size of the model

            binary_equal_between_batch: if the content of each input batch is
                binary equal, whether the content of each output batch is
                promised to be equal
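
        Examples:

            A minimal sketch, assuming LiteAlgoSelectStrategy is exported by
            megenginelite.struct like the other enums used in this module.

            .. code-block::

                strategy = (
                    LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
                    | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
                )
                network.set_network_algo_policy(strategy)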
  713. """
  714. self._api.LITE_set_network_algo_policy(self._network, policy)
  715. self._api.LITE_set_network_algo_fastrun_config(
  716. self._network, shared_batch_size, binary_equal_between_batch
  717. )
    def io_txt_dump(self, txt_file):
        """
        dump all the input/output tensors of all operators to the output file,
        in txt format; users can use this function to debug compute errors

        Args:
            txt_file: the txt file to dump to
        """
        c_file = txt_file.encode("utf-8")
        self._api.LITE_enable_io_txt_dump(self._network, c_file)

    def io_bin_dump(self, bin_dir):
        """
        dump all the input/output tensors of all operators to the output
        directory, in binary format; users can use this function to debug
        compute errors

        Args:
            bin_dir: the binary file directory to dump to
        """
        c_dir = bin_dir.encode("utf-8")
        self._api.LITE_enable_io_bin_dump(self._network, c_dir)

    def get_static_memory_alloc_info(self, log_dir="logs/test"):
        """
        get the static peak memory info shown by graph visualization

        Args:
            log_dir: the directory to save the information log to
        """
        c_log_dir = log_dir.encode("utf-8")
        self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)
    def enable_global_layout_transform(self):
        """
        enable global layout transform optimization for the network; global
        layout optimization can automatically determine the layout of every
        operator in the network by profiling, and thus can improve the
        performance of network forwarding
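
        Examples:

            A minimal sketch; ``"model_path"`` and ``"dumped_model"`` are
            placeholders, and the transform is enabled before loading.

            .. code-block::

                network = LiteNetwork()
                network.enable_global_layout_transform()
                network.load("model_path")
                network.dump_layout_transform_model("dumped_model")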
  750. """
  751. self._api.LITE_enable_global_layout_transform(self._network)
  752. def dump_layout_transform_model(self, model_file):
  753. """
  754. dump network after global layout transform optimization to the
  755. specific path
  756. Args:
  757. model_file: the file path to dump model
  758. """
  759. c_file = model_file.encode("utf-8")
  760. self._api.LITE_dump_layout_transform_model(self._network, c_file)


def get_model_io_info(model_path, config=None):
    """
    get the model io information from the model path, before the model is
    loaded.

    Args:
        model_path: the model path to get the model IO information from

        config: the model configuration

    Returns:
        the input and output information in the network configuration
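
    Examples:

        A minimal sketch; ``"model_path"`` is a placeholder.

        .. code-block::

            ios = get_model_io_info("model_path")
            print(ios)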
  769. """
  770. api = _NetworkAPI()._lib
  771. c_path = c_char_p(model_path.encode("utf-8"))
  772. ios = _LiteNetworkIO()
  773. if config is not None:
  774. api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
  775. else:
  776. config = LiteConfig()
  777. api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
  778. ret_ios = LiteNetworkIO()
  779. for i in range(ios.input_size):
  780. ret_ios.add_input(ios.inputs[i])
  781. for i in range(ios.output_size):
  782. ret_ios.add_output(ios.outputs[i])
  783. return ret_ios