You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

global_setting.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from ctypes import *
  10. import numpy as np
  11. from .base import _Ctensor, _lib, _LiteCObjBase
  12. from .network import *
  13. from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure
  14. from .tensor import *
  15. LiteDecryptionFunc = CFUNCTYPE(
  16. c_size_t, c_void_p, c_size_t, POINTER(c_uint8), c_size_t, c_void_p
  17. )
  18. class _GlobalAPI(_LiteCObjBase):
  19. """
  20. get the api from the lib
  21. """
  22. _api_ = [
  23. ("LITE_get_device_count", [c_int, POINTER(c_size_t)]),
  24. ("LITE_try_coalesce_all_free_memory", []),
  25. (
  26. "LITE_register_decryption_and_key",
  27. [c_char_p, LiteDecryptionFunc, POINTER(c_uint8), c_size_t],
  28. ),
  29. (
  30. "LITE_update_decryption_or_key",
  31. [c_char_p, c_void_p, POINTER(c_uint8), c_size_t],
  32. ),
  33. ("LITE_set_loader_lib_path", [c_char_p]),
  34. ("LITE_set_persistent_cache", [c_char_p, c_int]),
  35. # ('LITE_set_tensor_rt_cache', [c_char_p]),
  36. ("LITE_dump_persistent_cache", [c_char_p]),
  37. ("LITE_dump_tensor_rt_cache", [c_char_p]),
  38. ("LITE_register_memory_pair", [c_void_p, c_void_p, c_size_t, c_int, c_int]),
  39. ("LITE_clear_memory_pair", [c_void_p, c_void_p, c_int, c_int]),
  40. ("LITE_lookup_physic_ptr", [c_void_p, POINTER(c_void_p), c_int, c_int]),
  41. ]
  42. def decryption_func(func):
  43. """the decryption function decorator
  44. :type func: a function accept three array, in_arr, key_arr and out_arr, if out_arr is None, just query the out array lenght in byte
  45. """
  46. @CFUNCTYPE(c_size_t, c_void_p, c_size_t, POINTER(c_uint8), c_size_t, c_void_p)
  47. def wrapper(c_in_data, in_length, c_key_data, key_length, c_out_data):
  48. in_arr = np.frombuffer(c_in_data, dtype=np.uint8, count=in_length)
  49. key_arr = np.frombuffer(c_key_data, dtype=np.uint8, count=key_length)
  50. if c_out_data:
  51. out_length = func(in_arr, None)
  52. out_arr = np.frombuffer(c_out_data, dtype=np.uint8, count=out_length)
  53. return func(in_arr, key_arr, out_arr)
  54. # just query the output length
  55. else:
  56. return func(in_arr, key_arr, None)
  57. return wrapper
  58. class LiteGlobal(object):
  59. """
  60. some global config in lite
  61. """
  62. _api = _GlobalAPI()._lib
  63. @staticmethod
  64. def register_decryption_and_key(decryption_name, decryption_func, key):
  65. c_name = c_char_p(decryption_name.encode("utf-8"))
  66. key_length = len(key)
  67. c_key = (c_uint8 * key_length)(*key)
  68. LiteGlobal._api.LITE_register_decryption_and_key(
  69. c_name, decryption_func, c_key, key_length
  70. )
  71. @staticmethod
  72. def update_decryption_key(decryption_name, key):
  73. c_name = c_char_p(decryption_name.encode("utf-8"))
  74. key_length = len(key)
  75. c_key = (c_uint8 * key_length)(*key)
  76. LiteGlobal._api.LITE_update_decryption_or_key(c_name, None, c_key, key_length)
  77. @staticmethod
  78. def set_loader_lib_path(path):
  79. c_path = c_char_p(path.encode("utf-8"))
  80. LiteGlobal._api.LITE_set_loader_lib_path(c_path)
  81. @staticmethod
  82. def set_persistent_cache(path, always_sync=False):
  83. c_path = c_char_p(path.encode("utf-8"))
  84. LiteGlobal._api.LITE_set_persistent_cache(c_path, always_sync)
  85. @staticmethod
  86. def set_tensorrt_cache(path):
  87. c_path = c_char_p(path.encode("utf-8"))
  88. LiteGlobal._api.LITE_set_tensorrt_cache(c_path)
  89. @staticmethod
  90. def dump_persistent_cache(path):
  91. c_path = c_char_p(path.encode("utf-8"))
  92. LiteGlobal._api.LITE_dump_persistent_cache(c_path)
  93. @staticmethod
  94. def dump_tensorrt_cache():
  95. LiteGlobal._api.LITE_dump_tensorrt_cache()
  96. @staticmethod
  97. def get_device_count(device_type):
  98. count = c_size_t()
  99. LiteGlobal._api.LITE_get_device_count(device_type, byref(count))
  100. return count.value
  101. @staticmethod
  102. def try_coalesce_all_free_memory():
  103. LiteGlobal._api.LITE_try_coalesce_all_free_memory()
  104. @staticmethod
  105. def register_memory_pair(
  106. vir_ptr, phy_ptr, length, device, backend=LiteBackend.LITE_DEFAULT
  107. ):
  108. assert isinstance(vir_ptr, c_void_p) and isinstance(
  109. phy_ptr, c_void_p
  110. ), "clear memory pair only accept c_void_p type."
  111. LiteGlobal._api.LITE_register_memory_pair(
  112. vir_ptr, phy_ptr, length, device, backend
  113. )
  114. @staticmethod
  115. def clear_memory_pair(vir_ptr, phy_ptr, device, backend=LiteBackend.LITE_DEFAULT):
  116. assert isinstance(vir_ptr, c_void_p) and isinstance(
  117. phy_ptr, c_void_p
  118. ), "clear memory pair only accept c_void_p type."
  119. LiteGlobal._api.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend)
  120. @staticmethod
  121. def lookup_physic_ptr(vir_ptr, device, backend=LiteBackend.LITE_DEFAULT):
  122. assert isinstance(
  123. vir_ptr, c_void_p
  124. ), "lookup physic ptr only accept c_void_p type."
  125. mem = c_void_p()
  126. LiteGlobal._api.LITE_lookup_physic_ptr(vir_ptr, byref(mem), device, backend)
  127. return mem