You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from ctypes import *
  10. import numpy as np
  11. from .base import _Cnetwork, _Ctensor, _lib, _LiteCObjBase
  12. from .struct import *
  13. from .tensor import *
  class LiteOptions(Structure):
      """
      The inference options used to configure a network.

      Every field is a C int. Switch-like options are stored as 0/1 and shown
      as ``bool`` by ``__repr__``; ``jit_level``, ``comp_node_seq_record_level``,
      ``graph_opt_level`` and ``async_exec_level`` are integer levels. The
      ``enable_*`` fields toggle the layout transform options.
      """

      _fields_ = [
          ("weight_preprocess", c_int),
          ("fuse_preprocess", c_int),
          ("fake_next_exec", c_int),
          ("var_sanity_check_first_run", c_int),
          ("const_shape", c_int),
          ("force_dynamic_alloc", c_int),
          ("force_output_dynamic_alloc", c_int),
          ("force_output_use_user_specified_memory", c_int),
          ("no_profiling_on_shape_change", c_int),
          ("jit_level", c_int),
          ("comp_node_seq_record_level", c_int),
          ("graph_opt_level", c_int),
          ("async_exec_level", c_int),
          # layout transform options
          ("enable_nchw44", c_int),
          ("enable_nchw44_dot", c_int),
          ("enable_nchw88", c_int),
          ("enable_nhwcd4", c_int),
          ("enable_nchw4", c_int),
          ("enable_nchw32", c_int),
          ("enable_nchw64", c_int),
      ]

      def __init__(self):
          self.weight_preprocess = False
          self.fuse_preprocess = False
          self.fake_next_exec = False
          self.var_sanity_check_first_run = True
          self.const_shape = False
          self.force_dynamic_alloc = False
          self.force_output_dynamic_alloc = False
          self.force_output_use_user_specified_memory = False
          self.no_profiling_on_shape_change = False
          self.jit_level = 0
          self.comp_node_seq_record_level = 0
          self.graph_opt_level = 2
          self.async_exec_level = 1
          # Layout transforms default to off. ctypes already zero-fills the
          # struct; they are set explicitly so every default is visible here.
          self.enable_nchw44 = False
          self.enable_nchw44_dot = False
          self.enable_nchw88 = False
          self.enable_nhwcd4 = False
          self.enable_nchw4 = False
          self.enable_nchw32 = False
          self.enable_nchw64 = False

      def __repr__(self):
          # Keys mirror the names in _fields_; the previous version read a
          # non-existent "force_output_nocopy" attribute and always raised.
          data = {
              "weight_preprocess": bool(self.weight_preprocess),
              "fuse_preprocess": bool(self.fuse_preprocess),
              "fake_next_exec": bool(self.fake_next_exec),
              "var_sanity_check_first_run": bool(self.var_sanity_check_first_run),
              "const_shape": bool(self.const_shape),
              "force_dynamic_alloc": bool(self.force_dynamic_alloc),
              "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
              "force_output_use_user_specified_memory": bool(
                  self.force_output_use_user_specified_memory
              ),
              "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
              "jit_level": self.jit_level,
              "comp_node_seq_record_level": self.comp_node_seq_record_level,
              "graph_opt_level": self.graph_opt_level,
              "async_exec_level": self.async_exec_level,
              "enable_nchw44": bool(self.enable_nchw44),
              "enable_nchw44_dot": bool(self.enable_nchw44_dot),
              "enable_nchw88": bool(self.enable_nchw88),
              "enable_nhwcd4": bool(self.enable_nhwcd4),
              "enable_nchw4": bool(self.enable_nchw4),
              "enable_nchw32": bool(self.enable_nchw32),
              "enable_nchw64": bool(self.enable_nchw64),
          }
          return repr(data)
  class LiteConfig(Structure):
      """
      Configuration used when loading and compiling the graph.

      has_compression: flag whether the model is compressed; the compression
          method is used to read the model.
      device_id: the integer device id (the same value LiteNetwork.device_id
          reads and writes).
      device_type: a ``LiteDeviceType`` value, ``LITE_CPU`` by default.
      backend: a ``LiteBackend`` value, ``LITE_DEFAULT`` by default.
      bare_model_cryption_name: the decryption method name for a bare model,
          i.e. a model that does not pack model info inside.
      options: the ``LiteOptions`` used to configure the network.
      use_loader_dynamic_param: when the model forwards with the device loader
          of an NPU, flags whether the loader uses device input or output
          (non-zero) or not (zero). NOTE: this is a plain Python attribute, not
          part of ``_fields_``, so it is never passed to the C library.
      """

      _fields_ = [
          ("has_compression", c_int),
          ("device_id", c_int),
          ("device_type", c_int),
          ("backend", c_int),
          ("bare_model_cryption_name", c_char_p),
          ("options", LiteOptions),
      ]

      def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
          self.device_type = device_type
          if option:
              self.options = option
          else:
              self.options = LiteOptions()
          self.bare_model_cryption_name = c_char_p(b"")
          self.use_loader_dynamic_param = 0
          self.has_compression = 0
          self.backend = LiteBackend.LITE_DEFAULT

      def __repr__(self):
          # A c_char_p field reads back as bytes, or None when the pointer is NULL.
          cryption_name = self.bare_model_cryption_name
          data = {
              "has_compression": bool(self.has_compression),
              # device_id is a plain integer id, not a LiteDeviceType member;
              # wrapping it in LiteDeviceType mislabeled it or raised ValueError.
              "device_id": self.device_id,
              "device_type": LiteDeviceType(self.device_type),
              "backend": LiteBackend(self.backend),
              "bare_model_cryption_name": (
                  cryption_name.decode("utf-8") if cryption_name is not None else None
              ),
              "options": self.options,
          }
          return repr(data)
  class LiteIO(Structure):
      """
      Configuration of one network input or output item.

      name: the tensor name in the graph corresponding to the IO. Accepts
          ``str`` (encoded as UTF-8) or ``bytes``.
      is_host: marks where an input comes from and where an output is copied
          to. If true, the input is read from the host and the output is copied
          to the host; otherwise the device is used (e.g. the input already
          lives on the device and the output need not be copied back to the
          host). Default is true.
      io_type: the IO type, SHAPE or VALUE. When SHAPE is set, the value of the
          input or output tensor is invalid and only its shape is set. Default
          is VALUE.
      config_layout: the layout configured by the user. If another layout is set
          before forward or fetched after forward, this layout is bypassed; if
          no other layout is set before forward, this layout takes effect. If it
          is not set, the model forwards with its original layout. For outputs
          it is used for checking.
      """

      _fields_ = [
          ("name", c_char_p),
          ("is_host", c_int),
          ("io_type", c_int),
          ("config_layout", LiteLayout),
      ]

      def __init__(
          self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None
      ):
          # The C field is a char*, so text names are encoded to UTF-8 bytes;
          # isinstance (rather than an exact type check) also accepts str subclasses.
          if isinstance(name, str):
              name = name.encode("utf-8")
          self.name = c_char_p(name)
          if layout:
              self.config_layout = layout
          else:
              self.config_layout = LiteLayout()
          self.is_host = is_host
          self.io_type = io_type

      def __repr__(self):
          data = {
              "name": self.name,
              "is_host": bool(self.is_host),
              "io_type": LiteIOType(self.io_type),
              "config_layout": self.config_layout,
          }
          return data.__repr__()

      def __hash__(self):
          # Hashes by name only. No __eq__ is defined here, so equality stays
          # the inherited default (identity).
          return hash(self.name)
  class _LiteNetworkIO(Structure):
      """
      C-side description of the network IO handed to LITE_make_network.

      Holds a pointer to an array of LiteIO inputs, a pointer to an array of
      LiteIO outputs, and the number of elements in each.
      """

      _fields_ = [
          ("inputs", POINTER(LiteIO)),
          ("outputs", POINTER(LiteIO)),
          ("input_size", c_size_t),
          ("output_size", c_size_t),
      ]

      def __init__(self):
          # Start out with NULL pointers and zero counts; callers fill them in.
          io_ptr_type = POINTER(LiteIO)
          self.inputs = io_ptr_type()
          self.outputs = io_ptr_type()
          self.input_size = self.output_size = 0
  class LiteNetworkIO(object):
      """
      User-facing collection of LiteIO items used to build a _LiteNetworkIO.

      Inputs and outputs live in plain Python lists until
      ``_create_network_io`` packs them into ctypes arrays.
      """

      def __init__(self):
          self.inputs, self.outputs = [], []

      def add_input(self, input_io):
          """Append ``input_io`` (must be a LiteIO) to the inputs."""
          assert isinstance(input_io, LiteIO)
          self.inputs.append(input_io)

      def add_output(self, output_io):
          """Append ``output_io`` (must be a LiteIO) to the outputs."""
          assert isinstance(output_io, LiteIO)
          self.outputs.append(output_io)

      def _create_network_io(self):
          """
          Copy the collected items into ctypes arrays and return a
          _LiteNetworkIO that points at them and records their counts.
          """
          n_in = len(self.inputs)
          n_out = len(self.outputs)
          # An empty list still gets a one-slot (zeroed) array so that element
          # [0] exists to take a pointer to; the size fields carry the real
          # counts. The arrays are kept on self, presumably so the memory the
          # returned struct points into stays referenced.
          self.c_inputs = (LiteIO * max(n_in, 1))(*self.inputs)
          self.c_outputs = (LiteIO * max(n_out, 1))(*self.outputs)
          c_io = _LiteNetworkIO()
          c_io.inputs = pointer(self.c_inputs[0])
          c_io.outputs = pointer(self.c_outputs[0])
          c_io.input_size = n_in
          c_io.output_size = n_out
          return c_io

      def __repr__(self):
          return repr({"inputs": list(self.inputs), "outputs": list(self.outputs)})
  # ctypes prototypes for callbacks registered with the C library.
  # Async callback: takes no arguments and returns a C int.
  LiteAsyncCallback = CFUNCTYPE(c_int)
  # Start/finish callbacks receive (LiteIO array, tensor handle array, count).
  LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
  # Identical C signature; ctypes caches CFUNCTYPE prototypes by signature, so
  # this is the very same type object a second CFUNCTYPE(...) call would return.
  LiteFinishCallback = LiteStartCallback
  def wrap_async_callback(func):
      """
      Wrap the zero-argument Python callable ``func`` as a C ``int (*)(void)``
      callback and return it.

      The callback is also bound to the module-level name ``wrapper`` (shared
      with start_finish_callback), which keeps only the most recently created
      callback referenced. Any earlier one relies on the caller holding it.
      """
      global wrapper

      def _call_user():
          return func()

      wrapper = CFUNCTYPE(c_int)(_call_user)
      return wrapper
  def start_finish_callback(func):
      """
      Wrap ``func`` as a C start/finish callback and return it.

      When invoked from C with (LiteIO array, tensor handle array, count), the
      callback builds a dict mapping each LiteIO to a LiteTensor wrapping the
      matching handle, and returns ``func(that_dict)``.

      The callback is also bound to the module-level name ``wrapper`` (shared
      with wrap_async_callback), which keeps only the most recently created
      callback referenced.
      """
      global wrapper

      def _to_lite_tensor(c_tensor):
          # Point a fresh LiteTensor at the raw C handle, then call update()
          # (presumably to sync its Python-side view with the C tensor).
          lite_tensor = LiteTensor()
          lite_tensor._tensor = c_void_p(c_tensor)
          lite_tensor.update()
          return lite_tensor

      def _dispatch(c_ios, c_tensors, size):
          io_to_tensor = {}
          for idx in range(size):
              io_to_tensor[c_ios[idx]] = _to_lite_tensor(c_tensors[idx])
          return func(io_to_tensor)

      wrapper = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)(
          _dispatch
      )
      return wrapper
  class _NetworkAPI(_LiteCObjBase):
      """
      Declares the network-related functions of the lite C library.

      Each ``_api_`` entry is ``(symbol name, argtypes)``. LiteNetwork reads the
      bound functions from ``_NetworkAPI()._lib``; the binding itself is done by
      ``_LiteCObjBase`` (not visible here).
      """

      _api_ = [
          ("LITE_make_default_network", [POINTER(_Cnetwork)]),
          ("LITE_make_network", [POINTER(_Cnetwork), LiteConfig, _LiteNetworkIO]),
          ("LITE_load_model_from_mem", [_Cnetwork, c_void_p, c_size_t]),
          ("LITE_load_model_from_path", [_Cnetwork, c_char_p]),
          # Both arguments are network handles: LiteNetwork.share_weights_with
          # passes (self._network, src_network._network). This was previously
          # declared as _Ctensor.
          ("LITE_shared_weight_with_network", [_Cnetwork, _Cnetwork]),
          ("LITE_destroy_network", [_Cnetwork]),
          ("LITE_forward", [_Cnetwork]),
          ("LITE_wait", [_Cnetwork]),
          ("LITE_get_io_tensor", [_Cnetwork, c_char_p, c_int, POINTER(_Ctensor)]),
          ("LITE_get_input_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
          ("LITE_get_output_name", [_Cnetwork, c_size_t, POINTER(c_char_p)]),
          ("LITE_get_all_input_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
          ("LITE_get_all_output_name", [_Cnetwork, POINTER(c_size_t), POINTER(c_char_p)]),
          ("LITE_is_cpu_inplace_mode", [_Cnetwork, POINTER(c_int)]),
          ("LITE_get_cpu_threads_number", [_Cnetwork, POINTER(c_size_t)]),
          ("LITE_get_device_id", [_Cnetwork, POINTER(c_int)]),
          ("LITE_set_device_id", [_Cnetwork, c_int]),
          ("LITE_set_cpu_inplace_mode", [_Cnetwork]),
          ("LITE_use_tensorrt", [_Cnetwork]),
          ("LITE_set_cpu_threads_number", [_Cnetwork, c_size_t]),
          ("LITE_set_stream_id", [_Cnetwork, c_int]),
          ("LITE_get_stream_id", [_Cnetwork, POINTER(c_int)]),
          ("LITE_set_network_algo_policy", [_Cnetwork, c_int]),
          ("LITE_set_network_algo_fastrun_config", [_Cnetwork, c_int, c_int]),
          ("LITE_set_network_algo_workspace_limit", [_Cnetwork, c_size_t]),
          # This name is the symbol looked up in the library; keep its spelling
          # ("memroy") unless the library's export changes.
          ("LITE_share_runtime_memroy", [_Cnetwork, _Cnetwork]),
          ("LITE_enable_profile_performance", [_Cnetwork, c_char_p]),
          ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
          ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
          ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
          ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
          ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
          ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
      ]
  class LiteNetwork(object):
      """
      The network used to load a model and run forward.
      """

      _api = _NetworkAPI()._lib

      def __init__(self, config=None, io=None):
          """
          Create a network from ``config`` (a LiteConfig) and ``io`` (a
          LiteNetworkIO). A default LiteConfig / LiteNetworkIO is used for
          whichever is not given.
          """
          self._network = _Cnetwork()
          # Wrapped callbacks registered with the C library. ctypes does not keep
          # CFUNCTYPE objects alive for code that calls them later, so they are
          # held here for as long as the network exists. Previously only a
          # single shared module global referenced them.
          self._async_callback = None
          self._start_callback = None
          self._finish_callback = None
          if config:
              self.config = config
          else:
              self.config = LiteConfig()
          if io:
              self.network_io = io
          else:
              self.network_io = LiteNetworkIO()
          c_network_io = self.network_io._create_network_io()
          self._api.LITE_make_network(byref(self._network), self.config, c_network_io)

      def __repr__(self):
          data = {"config": self.config, "IOs": self.network_io}
          return data.__repr__()

      def __del__(self):
          # Guard against __init__ having failed before _network was assigned.
          network = getattr(self, "_network", None)
          if network is not None:
              self._api.LITE_destroy_network(network)

      def load(self, path):
          """Load the model file at ``path`` (a str, encoded as UTF-8)."""
          c_path = c_char_p(path.encode("utf-8"))
          self._api.LITE_load_model_from_path(self._network, c_path)

      def forward(self):
          """Run forward via LITE_forward."""
          self._api.LITE_forward(self._network)

      def wait(self):
          """Call LITE_wait (presumably blocks until the forward completes)."""
          self._api.LITE_wait(self._network)

      def is_cpu_inplace_mode(self):
          """
          Return whether the network runs in CPU inplace mode.
          """
          inplace = c_int()
          self._api.LITE_is_cpu_inplace_mode(self._network, byref(inplace))
          return bool(inplace.value)

      def enable_cpu_inplace_mode(self):
          """
          Run CPU forward in inplace mode, in which CPU forward creates only
          one thread.
          Note: this must be set before the network is loaded.
          """
          self._api.LITE_set_cpu_inplace_mode(self._network)

      def use_tensorrt(self):
          """
          Ask the library to use TensorRT for this network.
          Note: this must be set before the network is loaded.
          """
          self._api.LITE_use_tensorrt(self._network)

      @property
      def device_id(self):
          """
          Get the device id.
          """
          device_id = c_int()
          self._api.LITE_get_device_id(self._network, byref(device_id))
          return device_id.value

      @device_id.setter
      def device_id(self, device_id):
          """
          Set the device id.
          Note: this must be set before the network is loaded.
          """
          self._api.LITE_set_device_id(self._network, device_id)

      @property
      def stream_id(self):
          """
          Get the stream id.
          """
          stream_id = c_int()
          self._api.LITE_get_stream_id(self._network, byref(stream_id))
          return stream_id.value

      @stream_id.setter
      def stream_id(self, stream_id):
          """
          Set the stream id.
          Note: this must be set before the network is loaded.
          """
          self._api.LITE_set_stream_id(self._network, stream_id)

      @property
      def threads_number(self):
          """
          Get the CPU thread number of the network.
          """
          nr_thread = c_size_t()
          self._api.LITE_get_cpu_threads_number(self._network, byref(nr_thread))
          return nr_thread.value

      @threads_number.setter
      def threads_number(self, nr_threads):
          """
          Run the network forward in multithread mode with ``nr_threads`` threads.
          Note: this must be set before the network is loaded.
          """
          self._api.LITE_set_cpu_threads_number(self._network, nr_threads)

      def get_io_tensor(self, name, phase=LiteTensorPhase.LITE_IO):
          """
          Get an input or output tensor by its name.

          name: ``str`` (encoded as UTF-8) or ``bytes``.
          phase: a LiteTensorPhase value, ``LITE_IO`` by default.
          """
          if isinstance(name, str):
              name = name.encode("utf-8")
          c_name = c_char_p(name)
          tensor = LiteTensor()
          self._api.LITE_get_io_tensor(
              self._network, c_name, phase, byref(tensor._tensor)
          )
          tensor.update()
          return tensor

      def get_input_name(self, index):
          """
          Get the input name at ``index`` in the network.
          """
          c_name = c_char_p()
          self._api.LITE_get_input_name(self._network, index, byref(c_name))
          return c_name.value.decode("utf-8")

      def get_output_name(self, index):
          """
          Get the output name at ``index`` in the network.
          """
          c_name = c_char_p()
          self._api.LITE_get_output_name(self._network, index, byref(c_name))
          return c_name.value.decode("utf-8")

      def _get_all_names(self, api_fn):
          """
          Collect every name through a LITE_get_all_*_name style ``api_fn``.

          The C function is called twice: first with only the count pointer to
          learn how many names exist, then with a buffer of that many ``char*``
          slots. Returns an empty list when the count is zero (the old code
          raised UnboundLocalError in that case).
          """
          count = c_size_t()
          api_fn(self._network, byref(count), None)
          if count.value == 0:
              return []
          names = (c_char_p * count.value)()
          api_fn(self._network, None, names)
          return [name.decode("utf-8") for name in names]

      def get_all_input_name(self):
          """
          Get all input tensor names in the network (empty list if none).
          """
          return self._get_all_names(self._api.LITE_get_all_input_name)

      def get_all_output_name(self):
          """
          Get all output tensor names in the network (empty list if none).
          """
          return self._get_all_names(self._api.LITE_get_all_output_name)

      def share_weights_with(self, src_network):
          """
          Share weights with the already-loaded network ``src_network``.
          """
          assert isinstance(src_network, LiteNetwork)
          self._api.LITE_shared_weight_with_network(self._network, src_network._network)

      def share_runtime_memroy(self, src_network):
          """
          Share runtime memory with the source network ``src_network``.
          """
          assert isinstance(src_network, LiteNetwork)
          self._api.LITE_share_runtime_memroy(self._network, src_network._network)

      def async_with_callback(self, async_callback):
          """
          Register the zero-argument ``async_callback`` through
          LITE_set_async_callback.
          """
          callback = wrap_async_callback(async_callback)
          # Keep the CFUNCTYPE object referenced while C may still call it.
          self._async_callback = callback
          self._api.LITE_set_async_callback(self._network, callback)

      def set_start_callback(self, start_callback):
          """
          When the network starts forward, the callback is called with a dict
          mapping each LiteIO to the corresponding LiteTensor.
          """
          callback = start_finish_callback(start_callback)
          # Keep the CFUNCTYPE object referenced while C may still call it.
          self._start_callback = callback
          self._api.LITE_set_start_callback(self._network, callback)

      def set_finish_callback(self, finish_callback):
          """
          When the network finishes forward, the callback is called with a dict
          mapping each LiteIO to the corresponding LiteTensor.
          """
          callback = start_finish_callback(finish_callback)
          # Keep the CFUNCTYPE object referenced while C may still call it.
          self._finish_callback = callback
          self._api.LITE_set_finish_callback(self._network, callback)

      def enable_profile_performance(self, profile_file):
          """
          Enable performance profiling, passing ``profile_file`` (a str path,
          presumably where the profile is written) to the library.
          """
          c_file = profile_file.encode("utf-8")
          self._api.LITE_enable_profile_performance(self._network, c_file)

      def set_network_algo_workspace_limit(self, size_limit):
          """
          Set the algorithm workspace limit (passed to C as a size_t).
          """
          self._api.LITE_set_network_algo_workspace_limit(self._network, size_limit)

      def set_network_algo_policy(
          self, policy, shared_batch_size=0, binary_equal_between_batch=False
      ):
          """
          policy: the algorithm-selection policy value, passed as a C int to
              LITE_set_network_algo_policy.
          shared_batch_size: the batch size used by fastrun. A non-zero value
              means fastrun uses this batch size regardless of the batch size
              of the model; zero means fastrun uses the model's batch size.
          binary_equal_between_batch: if the content of each input batch is
              binary equal, whether the content of each output batch is
              promised to be equal.
          """
          self._api.LITE_set_network_algo_policy(self._network, policy)
          self._api.LITE_set_network_algo_fastrun_config(
              self._network, shared_batch_size, binary_equal_between_batch
          )

      def io_txt_dump(self, txt_file):
          """Enable text IO dumping to ``txt_file`` (a str path)."""
          c_file = txt_file.encode("utf-8")
          self._api.LITE_enable_io_txt_dump(self._network, c_file)

      def io_bin_dump(self, bin_dir):
          """Enable binary IO dumping into ``bin_dir`` (a str path)."""
          c_dir = bin_dir.encode("utf-8")
          self._api.LITE_enable_io_bin_dump(self._network, c_dir)

      def get_static_memory_alloc_info(self, log_dir="logs/test"):
          """Request static memory allocation info, passing ``log_dir`` to the library."""
          c_log_dir = log_dir.encode("utf-8")
          self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。