# -*- coding: utf-8 -*-
from ctypes import *

import numpy as np

from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
from .struct import *
from .tensor import *


class LiteOptions(Structure):
    """
    The inference options which can optimize the network forwarding
    performance.

    Attributes:
        weight_preprocess: the option which optimizes inference performance by
            processing the weights of the network ahead of time

        fuse_preprocess: fuse preprocess patterns, like astype + pad_channel +
            dimshuffle

        fake_next_exec: whether to only perform non-computing tasks (like
            memory allocation and queue initialization) for the next exec. This
            will be reset to false when the graph is executed.

        var_sanity_check_first_run: disable var sanity check on the first run.
            Var sanity check is enabled on the first-time execution by default,
            and can be used to find potential memory access errors in operators.

        const_shape: used to reduce memory usage and improve performance, since
            some static inference data structures can be omitted and some
            operators can be computed before forwarding

        force_dynamic_alloc: force dynamic memory allocation for all vars

        force_output_dynamic_alloc: force dynamic memory allocation for output
            tensors which are used as the input of the CallbackCaller operator

        force_output_use_user_specified_memory: force output tensors to use
            memory specified by the user

        no_profiling_on_shape_change: do not re-profile to select the best
            algorithm when the input shape changes (use the previous algorithm)

        jit_level: execute supported operators with JIT (supports MLIR and
            NVRTC). Can only be used on NVIDIA GPUs and x86 CPUs; the value
            indicates the JIT level:
            level 1: JIT execution of basic elemwise operators
            level 2: JIT execution of elemwise and reduce operators

        record_level: flag to optimize inference performance by recording the
            kernel tasks on the first run; afterwards all the inference needs
            to do is execute the recorded tasks.
            level = 0 means normal inference
            level = 1 means use record inference
            level = 2 means record inference and free the extra memory

        graph_opt_level: network optimization level:
            0: disable
            1: level-1: inplace arithmetic transformations during graph construction
            2: level-2: level-1, plus global optimization before graph compiling
            3: also enable JIT

        async_exec_level: level of dispatch on separate threads for different comp_nodes.
            0: do not perform async dispatch
            1: dispatch async if there is more than one comp node with a limited queue
            mask 0b10: async if there are multiple comp nodes
            mask 0b100: always async

    Examples:

        .. code-block::

            from megenginelite import *

            options = LiteOptions()
            options.weight_preprocess = True
            options.record_level = 1
            options.fuse_preprocess = True
    """

    _fields_ = [
        ("weight_preprocess", c_int),
        ("fuse_preprocess", c_int),
        ("fake_next_exec", c_int),
        ("var_sanity_check_first_run", c_int),
        ("const_shape", c_int),
        ("force_dynamic_alloc", c_int),
        ("force_output_dynamic_alloc", c_int),
        ("force_output_use_user_specified_memory", c_int),
        ("no_profiling_on_shape_change", c_int),
        ("jit_level", c_int),
        ("comp_node_seq_record_level", c_int),
        ("graph_opt_level", c_int),
        ("async_exec_level", c_int),
        # layout transform options
        ("enable_nchw44", c_int),
        ("enable_nchw44_dot", c_int),
        ("enable_nchw88", c_int),
        ("enable_nhwcd4", c_int),
        ("enable_nchw4", c_int),
        ("enable_nchw32", c_int),
        ("enable_nchw64", c_int),
    ]

    def __init__(self):
        self.weight_preprocess = False
        self.fuse_preprocess = False
        self.fake_next_exec = False
        self.var_sanity_check_first_run = True
        self.const_shape = False
        self.force_dynamic_alloc = False
        self.force_output_dynamic_alloc = False
        self.force_output_use_user_specified_memory = False
        self.no_profiling_on_shape_change = False
        self.jit_level = 0
        self.comp_node_seq_record_level = 0
        self.graph_opt_level = 2
        self.async_exec_level = 1

    def __repr__(self):
        data = {
            "weight_preprocess": bool(self.weight_preprocess),
            "fuse_preprocess": bool(self.fuse_preprocess),
            "fake_next_exec": bool(self.fake_next_exec),
            "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
            "const_shape": bool(self.const_shape),
            "force_dynamic_alloc": bool(self.force_dynamic_alloc),
            "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
            "force_output_use_user_specified_memory": bool(
                self.force_output_use_user_specified_memory
            ),
            "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
            "jit_level": self.jit_level,
            "comp_node_seq_record_level": self.comp_node_seq_record_level,
            "graph_opt_level": self.graph_opt_level,
            "async_exec_level": self.async_exec_level,
        }
        return data.__repr__()


class LiteConfig(Structure):
    """
    Configuration used when loading and compiling a network.

    Attributes:
        has_compression: flag marking whether the model is compressed; the
            compression method is stored in the model

        device_id: configure the device id of a network

        device_type: configure the device type of a network

        backend: configure the inference backend of a network; currently only
            megengine is supported

        bare_model_cryption_name: the bare model encryption method name. A bare
            model is not packed with json information, so this encryption method
            name is needed to decrypt an encrypted bare model.

        options: configuration of LiteOptions

        auto_optimize_inference: lite will detect the device information and set
            the options heuristically

        discrete_input_name: configure which input is composed of multiple
            discrete tensors

    Examples:

        .. code-block::

            from megenginelite import *

            config = LiteConfig()
            config.has_compression = False
            config.device_type = LiteDeviceType.LITE_CPU
            config.backend = LiteBackend.LITE_DEFAULT
            config.bare_model_cryption_name = "AES_default".encode("utf-8")
            config.auto_optimize_inference = False
    """

    _fields_ = [
        ("has_compression", c_int),
        ("device_id", c_int),
        ("device_type", c_int),
        ("backend", c_int),
        ("_bare_model_cryption_name", c_char_p),
        ("options", LiteOptions),
        ("auto_optimize_inference", c_int),
        ("discrete_input_name", c_char_p),
    ]

    def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
        self.device_type = device_type
        if option:
            self.options = option
        else:
            self.options = LiteOptions()

        self._bare_model_cryption_name = c_char_p(b"")
        self.use_loader_dynamic_param = 0
        self.has_compression = 0
        self.backend = LiteBackend.LITE_DEFAULT
        self.auto_optimize_inference = 0
        self.discrete_input_name = c_char_p(b"")

    @property
    def bare_model_cryption_name(self):
        return self._bare_model_cryption_name.decode("utf-8")

    @bare_model_cryption_name.setter
    def bare_model_cryption_name(self, name):
        if isinstance(name, str):
            self._bare_model_cryption_name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._bare_model_cryption_name = name

    def __repr__(self):
        data = {
            "has_compression": bool(self.has_compression),
            "device_id": self.device_id,
            "device_type": LiteDeviceType(self.device_type),
            "backend": LiteBackend(self.backend),
            "bare_model_cryption_name": self.bare_model_cryption_name,
            "options": self.options,
            "auto_optimize_inference": self.auto_optimize_inference,
            "discrete_input_name": self.discrete_input_name,
        }
        return data.__repr__()


class LiteExtraConfig(Structure):
    """
    Extra configuration used when loading and compiling the graph.

    Attributes:
        disable_configure_by_model_info: disable the configuration dumped with
            the model; if set to True, none of the configuration in the model
            will apply, and users should configure the network themselves.
    """

    _fields_ = [
        ("disable_configure_by_model_info", c_int),
    ]

    def __init__(self, disable_model_config=False):
        self.disable_configure_by_model_info = disable_model_config

    def __repr__(self):
        data = {
            "disable_configure_by_model_info": bool(
                self.disable_configure_by_model_info
            ),
        }
        return data.__repr__()
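

# Usage sketch (hedged; "model_path" is a placeholder): pair LiteExtraConfig
# with LiteNetwork.extra_configure (defined below) to ignore the configuration
# packed inside the model, so only the user-side configuration applies.
#
#     network = LiteNetwork()
#     network.extra_configure(LiteExtraConfig(disable_model_config=True))
#     network.load("model_path")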


class LiteIO(Structure):
    """
    Configure a network input or output item; the input and output tensor
    information is described here.

    Attributes:
        name: the tensor name in the graph corresponding to the IO

        is_host: marks where the input tensor comes from and where the output
            tensor will be copied to. If is_host is True, the input comes from
            the host and the output is copied to the host; otherwise they stay
            on the device. Sometimes the input comes from the device and the
            output need not be copied to the host. Default is True.

        io_type: the IO type, either SHAPE or VALUE. When SHAPE is set, the
            input or output tensor value is invalid and only the shape will be
            set. Default is VALUE.

        config_layout: the layout configured by the user. If another layout is
            set before forward or fetched after forward, this layout is
            bypassed; if no other layout is set before forward, this layout
            takes effect; if this layout is not set, the model forwards with
            its original layout. For outputs, it is used for checking.

    Note:
        if another layout is set on the input tensor before forwarding, this layout will not work

        if no layout is set before forwarding, the model will forward with its original layout

        if a layout is set on the output tensor, it will be used to check whether the layout computed by the network is correct

    Examples:

        .. code-block::

            from megenginelite import *

            io = LiteIO(
                "data2",
                is_host=True,
                io_type=LiteIOType.LITE_IO_SHAPE,
                layout=LiteLayout([2, 4, 4]),
            )
    """

    _fields_ = [
        ("_name", c_char_p),
        ("is_host", c_int),
        ("io_type", c_int),
        ("config_layout", LiteLayout),
    ]

    def __init__(
        self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        if type(name) == str:
            self._name = c_char_p(name.encode("utf-8"))
        else:
            self._name = c_char_p(name)

        if layout:
            self.config_layout = layout
        else:
            self.config_layout = LiteLayout()

        self.is_host = is_host
        self.io_type = io_type

    @property
    def name(self):
        """
        get the name of the IO item
        """
        return self._name.decode("utf-8")

    @name.setter
    def name(self, name):
        """
        set the name of the IO item
        """
        if isinstance(name, str):
            self._name = name.encode("utf-8")
        else:
            assert isinstance(name, bytes), "name should be str or bytes type."
            self._name = name

    def __repr__(self):
        data = {
            "name": self.name,
            "is_host": bool(self.is_host),
            "io_type": LiteIOType(self.io_type),
            "config_layout": self.config_layout,
        }
        return data.__repr__()

    def __hash__(self):
        return hash(self.name)


class _LiteNetworkIO(Structure):
    _fields_ = [
        ("inputs", POINTER(LiteIO)),
        ("outputs", POINTER(LiteIO)),
        ("input_size", c_size_t),
        ("output_size", c_size_t),
    ]

    def __init__(self):
        self.inputs = POINTER(LiteIO)()
        self.outputs = POINTER(LiteIO)()
        self.input_size = 0
        self.output_size = 0


class LiteNetworkIO(object):
    """
    The user-facing input and output information used when loading the network.
    The NetworkIO will remain in the network until the network is destroyed.

    Attributes:
        inputs: all the input tensor information that will be configured into the network

        outputs: all the output tensor information that will be configured into the network

    Examples:

        .. code-block::

            from megenginelite import *

            input_io = LiteIO("data", is_host=False, io_type=LiteIOType.LITE_IO_VALUE)
            io = LiteNetworkIO()
            io.add_input(input_io)
            output_io = LiteIO("out", is_host=True, layout=LiteLayout([1, 1000]))
            io.add_output(output_io)
    """

    def __init__(self, inputs=None, outputs=None):
        self.inputs = []
        self.outputs = []
        if inputs:
            for i in inputs:
                if isinstance(i, list):
                    self.inputs.append(LiteIO(*i))
                else:
                    assert isinstance(
                        i, LiteIO
                    ), "the parameters used to construct LiteNetworkIO must be a list of LiteIO members, or a LiteIO."
                    self.inputs.append(i)
        if outputs:
            for i in outputs:
                if isinstance(i, list):
                    self.outputs.append(LiteIO(*i))
                else:
                    assert isinstance(
                        i, LiteIO
                    ), "the parameters used to construct LiteNetworkIO must be a list of LiteIO members, or a LiteIO."
                    self.outputs.append(i)

    def add_input(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add input information into LiteNetworkIO
        """
        if isinstance(obj, LiteIO):
            self.inputs.append(obj)
        else:
            name = obj
            self.add_input(LiteIO(name, is_host, io_type, layout))

    def add_output(
        self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
    ):
        """
        add output information into LiteNetworkIO
        """
        if isinstance(obj, LiteIO):
            self.outputs.append(obj)
        else:
            name = obj
            self.add_output(LiteIO(name, is_host, io_type, layout))

    def _create_network_io(self):
        network_io = _LiteNetworkIO()
        length = 1 if len(self.inputs) == 0 else len(self.inputs)
        self.c_inputs = (LiteIO * length)(*self.inputs)
        length = 1 if len(self.outputs) == 0 else len(self.outputs)
        self.c_outputs = (LiteIO * length)(*self.outputs)
        network_io.inputs = pointer(self.c_inputs[0])
        network_io.outputs = pointer(self.c_outputs[0])
        network_io.input_size = len(self.inputs)
        network_io.output_size = len(self.outputs)
        return network_io

    def __repr__(self):
        data = {"inputs": list(self.inputs), "outputs": list(self.outputs)}
        return data.__repr__()


LiteAsyncCallback = CFUNCTYPE(c_int)
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)


def wrap_async_callback(func):
    global wrapper

    @CFUNCTYPE(c_int)
    def wrapper():
        return func()

    return wrapper


def start_finish_callback(func):
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        ios = {}
        for i in range(size):
            tensor = LiteTensor(physic_construct=False)
            tensor._tensor = c_void_p(c_tensors[i])
            tensor.update()
            io = c_ios[i]
            ios[io] = tensor
        return func(ios)

    return wrapper
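

# A minimal sketch of a start/finish callback (hedged; names are illustrative).
# The wrapped Python function receives a dict mapping each LiteIO to its
# LiteTensor and must return an int status code (0 on success):
#
#     def print_outputs(ios):
#         for io, tensor in ios.items():
#             print(io.name, tensor.layout)
#         return 0
#
#     network.set_finish_callback(print_outputs)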


class _NetworkAPI(_LiteCObjBase):
    """
    get the network API from the lib
    """

    _api_ = [
        ("LITE_make_default_network", [POINTER(_Cnetwork)]),
        ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
        ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
        ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
        ("LITE_shared_weight_with_network", [_Cnetwork, _Ctensor]),
        ("LITE_destroy_network", [_Cnetwork]),
        ("LITE_forward", [_Cnetwork]),
        ("LITE_wait", [_Cnetwork]),
        ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
        ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
        ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
        ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
        ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
        ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_device_id", [_Cnetwork, c_int]),
        ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
        ("LITE_use_tensorrt", [_Cnetwork]),
        ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
        ("LITE_set_stream_id", [_Cnetwork, c_int]),
        ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
        ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
        ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
        ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
        ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
        ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
        ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
        ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
        (
            "LITE_get_model_io_info_by_path",
            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        (
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
        (
            "LITE_get_discrete_tensor",
            [_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)],
        ),
    ]


class LiteNetwork(object):
    """
    The network to load a model and forward.

    Examples:

        .. code-block::

            from megenginelite import *

            config = LiteConfig()
            config.device_type = LiteDeviceType.LITE_CPU
            network = LiteNetwork(config)
            network.load("model_path")

            input_name = network.get_input_name(0)
            input_tensor = network.get_io_tensor(input_name)
            output_name = network.get_output_name(0)
            output_tensor = network.get_io_tensor(output_name)

            input_tensor.set_data_by_copy(input_data)
            network.forward()
            network.wait()
    """

    _api = _NetworkAPI()._lib

    def __init__(self, config=None, io=None):
        """
        create a network with config and networkio
        """
        self._network = _Cnetwork()

        if config:
            self.config = config
        else:
            self.config = LiteConfig()

        if io:
            self.network_io = io
        else:
            self.network_io = LiteNetworkIO()

        c_network_io = self.network_io._create_network_io()
        self._api.LITE_make_network(byref(self._network), self.config, c_network_io)

    def __repr__(self):
        data = {"config": self.config, "IOs": self.network_io}
        return data.__repr__()

    def __del__(self):
        self._api.LITE_destroy_network(self._network)

    def load(self, path):
        """
        load the network from the given path
        """
        c_path = c_char_p(path.encode("utf-8"))
        self._api.LITE_load_model_from_path(self._network, c_path)

    def forward(self):
        """
        forward the network with the filled input data and fill the output
        data to the output tensor
        """
        self._api.LITE_forward(self._network)

    def wait(self):
        """
        wait until the forward finishes in sync mode
        """
        self._api.LITE_wait(self._network)

    def is_cpu_inplace_mode(self):
        """
        whether the network runs in cpu inplace mode

        Returns:
            True if inplace mode is used, else False
        """
        inplace = c_int()
        self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
        return bool(inplace.value)

    def enable_cpu_inplace_mode(self):
        """
        set cpu forward in inplace mode, with which cpu forward only creates
        one thread

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_cpu_inplace_mode(self._network)

    def use_tensorrt(self):
        """
        use TensorRT

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_use_tensorrt(self._network)

    @property
    def device_id(self):
        """
        get the device id

        Returns:
            the device id used by the current network
        """
        device_id = c_int()
        self._api.LITE_get_device_id(self._network, byref(device_id))
        return device_id.value

    @device_id.setter
    def device_id(self, device_id):
        """
        set the device id

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_device_id(self._network, device_id)

    @property
    def stream_id(self):
        """
        get the stream id

        Returns:
            the stream id set for the network
        """
        stream_id = c_int()
        self._api.LITE_get_stream_id(self._network, byref(stream_id))
        return stream_id.value

    @stream_id.setter
    def stream_id(self, stream_id):
        """
        set the stream id

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_stream_id(self._network, stream_id)

    @property
    def threads_number(self):
        """
        get the thread number of the network

        Returns:
            the number of threads set in the network
        """
        nr_thread = c_size_t()
        self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
        return nr_thread.value

    @threads_number.setter
    def threads_number(self, nr_threads):
        """
        set the network to forward in multithread mode, with the given thread
        number

        Note:
            this must be set before the network is loaded
        """
        self._api.LITE_set_cpu_threads_number(self._network, nr_threads)
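
    # Sketch: configure a 4-thread CPU forward before loading ("model_path"
    # is a placeholder); the thread number must be set before load().
    #
    #     network = LiteNetwork()
    #     network.threads_number = 4
    #     network.load("model_path")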

    def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
        """
        get the input or output tensor by its name

        Args:
            name: the name of the io tensor
            phase: the type of the LiteTensor; this is useful to separate input and output tensors with the same name

        Returns:
            the tensor with the given name and type
        """
        if type(name) == str:
            c_name = c_char_p(name.encode("utf-8"))
        else:
            c_name = c_char_p(name)
        tensor = LiteTensor(physic_construct=False)
        self._api.LITE_get_io_tensor(
            self._network, c_name, phase, byref(tensor._tensor)
        )
        tensor.update()
        return tensor

    def get_discrete_tensor(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
        """
        get the n_idx-th tensor of the network input named ``name``, whose
        input consists of multiple discrete tensors

        Args:
            name: the name of the input tensor
            n_idx: the tensor index
            phase: the type of the LiteTensor; this is useful to separate input tensors with the same name

        Returns:
            the tensor with the given name and index
        """
        if type(name) == str:
            c_name = c_char_p(name.encode("utf-8"))
        else:
            c_name = c_char_p(name)
        tensor = LiteTensor(physic_construct=False)
        self._api.LITE_get_discrete_tensor(
            self._network, c_name, n_idx, phase, byref(tensor._tensor)
        )
        tensor.update()
        return tensor

    def get_input_name(self, index):
        """
        get the input name by index in the network

        Args:
            index: the index of the input name

        Returns:
            the name of the input tensor with the given index
        """
        c_name = c_char_p()
        self._api.LITE_get_input_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_output_name(self, index):
        """
        get the output name by index in the network

        Args:
            index: the index of the output name

        Returns:
            the name of the output tensor with the given index
        """
        c_name = c_char_p()
        self._api.LITE_get_output_name(self._network, index, byref(c_name))
        return c_name.value.decode("utf-8")

    def get_all_input_name(self):
        """
        get all the input tensor names in the network

        Returns:
            the names of all input tensors in the network
        """
        nr_input = c_size_t()
        self._api.LITE_get_all_input_name(self._network, byref(nr_input), None)
        if nr_input.value > 0:
            names = (c_char_p * nr_input.value)()
            self._api.LITE_get_all_input_name(self._network, None, names)
            ret_name = [names[i].decode("utf-8") for i in range(nr_input.value)]
            return ret_name

    def get_all_output_name(self):
        """
        get all the output tensor names in the network

        Returns:
            the names of all output tensors in the network
        """
        nr_output = c_size_t()
        self._api.LITE_get_all_output_name(self._network, byref(nr_output), None)
        if nr_output.value > 0:
            names = (c_char_p * nr_output.value)()
            self._api.LITE_get_all_output_name(self._network, None, names)
            ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
            return ret_name

    def extra_configure(self, extra_config):
        """
        apply the extra configuration to the network

        Args:
            extra_config: the LiteExtraConfig to apply
        """
        self._api.LITE_extra_configure(self._network, extra_config)

    def share_weights_with(self, src_network):
        """
        share weights with the loaded network

        Args:
            src_network: the network to share weights with
        """
        assert isinstance(src_network, LiteNetwork)
        self._api.LITE_shared_weight_with_network(self._network, src_network._network)

    def share_runtime_memroy(self, src_network):
        """
        share runtime memory with the source network

        Args:
            src_network: the network to share runtime memory with
        """
        assert isinstance(src_network, LiteNetwork)
        self._api.LITE_share_runtime_memroy(self._network, src_network._network)

    def async_with_callback(self, async_callback):
        """
        set the network to forward in async mode and set the AsyncCallback
        callback function

        Args:
            async_callback: the callback to set for the network
        """
        callback = wrap_async_callback(async_callback)
        self._api.LITE_set_async_callback(self._network, callback)
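
    # Sketch of asynchronous forwarding (hedged; names are illustrative). The
    # callback runs when the async forward finishes and must return an int
    # status code:
    #
    #     def on_finish():
    #         print("forward done")
    #         return 0
    #
    #     network.async_with_callback(on_finish)
    #     network.forward()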

    def set_start_callback(self, start_callback):
        """
        when the network starts forwarding, the callback will be called with a
        parameter mapping each LiteIO to the corresponding LiteTensor

        Args:
            start_callback: the callback to set for the network
        """
        callback = start_finish_callback(start_callback)
        self._api.LITE_set_start_callback(self._network, callback)

    def set_finish_callback(self, finish_callback):
        """
        when the network finishes forwarding, the callback will be called with
        a parameter mapping each LiteIO to the corresponding LiteTensor

        Args:
            finish_callback: the callback to set for the network
        """
        callback = start_finish_callback(finish_callback)
        self._api.LITE_set_finish_callback(self._network, callback)

    def enable_profile_performance(self, profile_file):
        """
        enable profiling of the network performance and save the profiled
        information into the given file

        Args:
            profile_file: the file to save the profile information
        """
        c_file = profile_file.encode("utf-8")
        self._api.LITE_enable_profile_performance(self._network, c_file)
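
    # Sketch: profile one forward pass and dump the result ("profile.json" is
    # a placeholder path; the file is written by the lite runtime):
    #
    #     network.enable_profile_performance("profile.json")
    #     network.forward()
    #     network.wait()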

    def set_network_algo_workspace_limit(self, size_limit):
        """
        set the operator workspace limitation of the target network; some
        operators may use a large workspace to get good performance, and
        setting a workspace limitation can save memory but may influence
        performance

        Args:
            size_limit: the byte size of the workspace limitation
        """
        self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)

    def set_network_algo_policy(
        self, policy, shared_batch_size=0, binary_equal_between_batch=False
    ):
        """
        set the network algorithm search policy for fast-run

        Args:
            policy: the algorithm search policy
            shared_batch_size: the batch size used by fast-run.
                A non-zero value means fast-run uses this batch size
                regardless of the batch size of the model; zero means
                fast-run uses the batch size of the model
            binary_equal_between_batch: if the content of each input batch is
                binary-equal, whether the content of each output batch is
                promised to be equal
        """
        self._api.LITE_set_network_algo_policy(self._network, policy)
        self._api.LITE_set_network_algo_fastrun_config(
            self._network, shared_batch_size, binary_equal_between_batch
        )
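
    # Sketch (assumption: a LiteAlgoSelectStrategy enum is exported by
    # megenginelite's struct module, mirroring the C API enum; verify the
    # constant names in your version): select algorithms by profiling with a
    # shared batch size of 1.
    #
    #     network.set_network_algo_policy(
    #         LiteAlgoSelectStrategy.LITE_ALGO_PROFILE, shared_batch_size=1
    #     )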

    def io_txt_dump(self, txt_file):
        """
        dump all the input/output tensors of all operators to the output file
        in txt format; users can use this function to debug compute errors

        Args:
            txt_file: the txt file
        """
        c_file = txt_file.encode("utf-8")
        self._api.LITE_enable_io_txt_dump(self._network, c_file)

    def io_bin_dump(self, bin_dir):
        """
        dump all the input/output tensors of all operators to output files in
        binary format; users can use this function to debug compute errors

        Args:
            bin_dir: the binary file directory
        """
        c_dir = bin_dir.encode("utf-8")
        self._api.LITE_enable_io_bin_dump(self._network, c_dir)

    def get_static_memory_alloc_info(self, log_dir="logs/test"):
        """
        get the static peak memory info shown by graph visualization

        Args:
            log_dir: the directory to save the information log
        """
        c_log_dir = log_dir.encode("utf-8")
        self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)

    def enable_global_layout_transform(self):
        """
        enable global layout transform optimization for the network; global
        layout optimization can automatically determine the layout of every
        operator in the network by profiling, and thus improve the performance
        of network forwarding
        """
        self._api.LITE_enable_global_layout_transform(self._network)

    def dump_layout_transform_model(self, model_file):
        """
        dump the network after global layout transform optimization to the
        specified path

        Args:
            model_file: the file path to dump the model to
        """
        c_file = model_file.encode("utf-8")
        self._api.LITE_dump_layout_transform_model(self._network, c_file)


def get_model_io_info(model_path, config=None):
    """
    get the model io information, by model path, before the model is loaded.

    Args:
        model_path: the model path to get the model IO information from
        config: the model configuration

    Returns:
        the input and output information in the network configuration
    """
    api = _NetworkAPI()._lib
    c_path = c_char_p(model_path.encode("utf-8"))

    ios = _LiteNetworkIO()

    if config is not None:
        api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
    else:
        config = LiteConfig()
        api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))

    ret_ios = LiteNetworkIO()
    for i in range(ios.input_size):
        ret_ios.add_input(ios.inputs[i])
    for i in range(ios.output_size):
        ret_ios.add_output(ios.outputs[i])
    return ret_ios
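

# Usage sketch for get_model_io_info ("model_path" is a placeholder): inspect
# a model's IO before constructing a LiteNetwork.
#
#     ios = get_model_io_info("model_path")
#     print([io.name for io in ios.inputs])
#     print([io.name for io in ios.outputs])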