You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

global_setting.py 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # -*- coding: utf-8 -*-
  2. from ctypes import *
  3. import numpy as np
  4. from .base import _Ctensor, _lib, _LiteCObjBase
  5. from .network import *
  6. from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure
  7. from .tensor import *
  8. LiteDecryptionFunc = CFUNCTYPE(
  9. c_size_t, c_void_p, c_size_t, POINTER(c_uint8), c_size_t, c_void_p
  10. )
  11. class _GlobalAPI(_LiteCObjBase):
  12. """
  13. Get APIs from the lib
  14. """
  15. _api_ = [
  16. ("LITE_get_device_count", [c_int, POINTER(c_size_t)]),
  17. ("LITE_try_coalesce_all_free_memory", []),
  18. (
  19. "LITE_register_decryption_and_key",
  20. [c_char_p, LiteDecryptionFunc, POINTER(c_uint8), c_size_t],
  21. ),
  22. (
  23. "LITE_update_decryption_or_key",
  24. [c_char_p, c_void_p, POINTER(c_uint8), c_size_t],
  25. ),
  26. ("LITE_set_loader_lib_path", [c_char_p]),
  27. ("LITE_set_persistent_cache", [c_char_p, c_int]),
  28. # ('LITE_set_tensor_rt_cache', [c_char_p]),
  29. ("LITE_dump_persistent_cache", [c_char_p]),
  30. ("LITE_dump_tensor_rt_cache", [c_char_p]),
  31. ("LITE_register_memory_pair", [c_void_p, c_void_p, c_size_t, c_int, c_int]),
  32. ("LITE_clear_memory_pair", [c_void_p, c_void_p, c_int, c_int]),
  33. ("LITE_lookup_physic_ptr", [c_void_p, POINTER(c_void_p), c_int, c_int]),
  34. ]
  35. def decryption_func(func):
  36. """the decryption function decorator.
  37. .. note::
  38. The function accept three array: ``in_arr``, ``key_arr`` and ``out_arr``.
  39. If ``out_arr`` is None, just query the out array length in byte.
  40. """
  41. @CFUNCTYPE(c_size_t, c_void_p, c_size_t, POINTER(c_uint8), c_size_t, c_void_p)
  42. def wrapper(c_in_data, in_length, c_key_data, key_length, c_out_data):
  43. in_arr = np.frombuffer(c_in_data, dtype=np.uint8, count=in_length)
  44. key_arr = np.frombuffer(c_key_data, dtype=np.uint8, count=key_length)
  45. if c_out_data:
  46. out_length = func(in_arr, None)
  47. out_arr = np.frombuffer(c_out_data, dtype=np.uint8, count=out_length)
  48. return func(in_arr, key_arr, out_arr)
  49. # just query the output length
  50. else:
  51. return func(in_arr, key_arr, None)
  52. return wrapper
  53. class LiteGlobal(object):
  54. """
  55. Some global config in lite
  56. """
  57. _api = _GlobalAPI()._lib
  58. @staticmethod
  59. def register_decryption_and_key(decryption_name, decryption_func, key):
  60. """Register a custom decryption method and key to lite
  61. Args:
  62. decryption_name: the name of the decryption, which will act as the hash
  63. key to find the decryption method.
  64. decryption_func: the decryption function, which will decrypt the model with
  65. the registered key, then return the decrypted model.
  66. See :py:func:`~.decryption_func` for more details.
  67. key: the decryption key of the method.
  68. """
  69. c_name = c_char_p(decryption_name.encode("utf-8"))
  70. key_length = len(key)
  71. c_key = (c_uint8 * key_length)(*key)
  72. LiteGlobal._api.LITE_register_decryption_and_key(
  73. c_name, decryption_func, c_key, key_length
  74. )
  75. @staticmethod
  76. def update_decryption_key(decryption_name, key):
  77. """Update decryption key of a custom decryption method.
  78. Args:
  79. decrypt_name: the name of the decryption,
  80. which will act as the hash key to find the decryption method.
  81. key: the decryption key of the method,
  82. if the length of key is zero, the key will not be updated.
  83. """
  84. c_name = c_char_p(decryption_name.encode("utf-8"))
  85. key_length = len(key)
  86. c_key = (c_uint8 * key_length)(*key)
  87. LiteGlobal._api.LITE_update_decryption_or_key(c_name, None, c_key, key_length)
  88. @staticmethod
  89. def set_loader_lib_path(path):
  90. """Set the loader path to be used in lite.
  91. Args:
  92. path: the file path which store the loader library.
  93. """
  94. c_path = c_char_p(path.encode("utf-8"))
  95. LiteGlobal._api.LITE_set_loader_lib_path(c_path)
  96. @staticmethod
  97. def set_persistent_cache(path, always_sync=False):
  98. """Set the algo policy cache file for CPU/CUDA,
  99. the algo policy cache is produced by MegEngine fast-run.
  100. Args:
  101. path: the file path which store the cache.
  102. always_sync: always update the cache file when model runs.
  103. """
  104. c_path = c_char_p(path.encode("utf-8"))
  105. LiteGlobal._api.LITE_set_persistent_cache(c_path, always_sync)
  106. @staticmethod
  107. def set_tensorrt_cache(path):
  108. """Set the TensorRT engine cache path for serialized prebuilt ICudaEngine.
  109. Args:
  110. path: the cache file path to set
  111. """
  112. c_path = c_char_p(path.encode("utf-8"))
  113. LiteGlobal._api.LITE_set_tensorrt_cache(c_path)
  114. @staticmethod
  115. def dump_persistent_cache(path):
  116. """Dump the PersistentCache policy cache to the specific file.
  117. If the network is set to profile when forward,
  118. though this the algo policy will dump to file.
  119. Args:
  120. path: the cache file path to be dumped.
  121. """
  122. c_path = c_char_p(path.encode("utf-8"))
  123. LiteGlobal._api.LITE_dump_persistent_cache(c_path)
  124. @staticmethod
  125. def dump_tensorrt_cache():
  126. """Dump the TensorRT cache to the file set in :py:func:`~.set_tensorrt_cache`."""
  127. LiteGlobal._api.LITE_dump_tensorrt_cache()
  128. @staticmethod
  129. def get_device_count(device_type):
  130. """Get the number of device of the given device type in current context.
  131. Args:
  132. device_type: the device type to be counted.
  133. Returns:
  134. the number of device.
  135. """
  136. count = c_size_t()
  137. LiteGlobal._api.LITE_get_device_count(device_type, byref(count))
  138. return count.value
  139. @staticmethod
  140. def try_coalesce_all_free_memory():
  141. """Try to coalesce all free memory in MegEngine.
  142. When call it MegEnine Lite will try to free all the unused memory
  143. thus decrease the runtime memory usage.
  144. """
  145. LiteGlobal._api.LITE_try_coalesce_all_free_memory()
  146. @staticmethod
  147. def register_memory_pair(
  148. vir_ptr, phy_ptr, length, device, backend=LiteBackend.LITE_DEFAULT
  149. ):
  150. """Register the physical and virtual address pair to the MegEngine,
  151. some device need the map from physical to virtual.
  152. Args:
  153. vir_ptr: the virtual ptr to set to MegEngine.
  154. phy_ptr: the physical ptr to set to MegEngine.
  155. length: the length of bytes to set pair memory.
  156. device: the the device to set the pair memory.
  157. backend: the backend to set the pair memory
  158. Return:
  159. Whether the register operation is successful.
  160. """
  161. assert isinstance(vir_ptr, c_void_p) and isinstance(
  162. phy_ptr, c_void_p
  163. ), "clear memory pair only accept c_void_p type."
  164. LiteGlobal._api.LITE_register_memory_pair(
  165. vir_ptr, phy_ptr, length, device, backend
  166. )
  167. @staticmethod
  168. def clear_memory_pair(vir_ptr, phy_ptr, device, backend=LiteBackend.LITE_DEFAULT):
  169. """Clear the physical and virtual address pair in MegEngine.
  170. Args:
  171. vir_ptr: the virtual ptr to set to MegEngine.
  172. phy_ptr: the physical ptr to set to MegEngine.
  173. device: the the device to set the pair memory.
  174. backend: the backend to set the pair memory.
  175. Return:
  176. Whether the clear is operation successful.
  177. """
  178. assert isinstance(vir_ptr, c_void_p) and isinstance(
  179. phy_ptr, c_void_p
  180. ), "clear memory pair only accept c_void_p type."
  181. LiteGlobal._api.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend)
  182. @staticmethod
  183. def lookup_physic_ptr(vir_ptr, device, backend=LiteBackend.LITE_DEFAULT):
  184. """Get the physic address by the virtual address in MegEngine.
  185. Args:
  186. vir_ptr: the virtual ptr to set to MegEngine.
  187. device: the the device to set the pair memory.
  188. backend: the backend to set the pair memory.
  189. Return:
  190. The physic address to lookup.
  191. """
  192. assert isinstance(
  193. vir_ptr, c_void_p
  194. ), "lookup physic ptr only accept c_void_p type."
  195. mem = c_void_p()
  196. LiteGlobal._api.LITE_lookup_physic_ptr(vir_ptr, byref(mem), device, backend)
  197. return mem