You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

global_setting.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # -*- coding: utf-8 -*-
  2. # This file is part of MegEngine, a deep learning framework developed by
  3. # Megvii.
  4. #
  5. # Copyright (c) Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  6. from ctypes import *
  7. import numpy as np
  8. from .base import _Ctensor, _lib, _LiteCObjBase
  9. from .network import *
  10. from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure
  11. from .tensor import *
  12. LiteDecryptionFunc = CFUNCTYPE(
  13. c_size_t, c_void_p, c_size_t, POINTER(c_uint8), c_size_t, c_void_p
  14. )
  15. class _GlobalAPI(_LiteCObjBase):
  16. """
  17. get the api from the lib
  18. """
  19. _api_ = [
  20. ("LITE_get_device_count", [c_int, POINTER(c_size_t)]),
  21. ("LITE_try_coalesce_all_free_memory", []),
  22. (
  23. "LITE_register_decryption_and_key",
  24. [c_char_p, LiteDecryptionFunc, POINTER(c_uint8), c_size_t],
  25. ),
  26. (
  27. "LITE_update_decryption_or_key",
  28. [c_char_p, c_void_p, POINTER(c_uint8), c_size_t],
  29. ),
  30. ("LITE_set_loader_lib_path", [c_char_p]),
  31. ("LITE_set_persistent_cache", [c_char_p, c_int]),
  32. # ('LITE_set_tensor_rt_cache', [c_char_p]),
  33. ("LITE_dump_persistent_cache", [c_char_p]),
  34. ("LITE_dump_tensor_rt_cache", [c_char_p]),
  35. ]
  36. def decryption_func(func):
  37. """the decryption function decorator
  38. :type func: a function accept three array, in_arr, key_arr and out_arr, if out_arr is None, just query the out array lenght in byte
  39. """
  40. @CFUNCTYPE(c_size_t, c_void_p, c_size_t, POINTER(c_uint8), c_size_t, c_void_p)
  41. def wrapper(c_in_data, in_length, c_key_data, key_length, c_out_data):
  42. in_arr = np.frombuffer(c_in_data, dtype=np.uint8, count=in_length)
  43. key_arr = np.frombuffer(c_key_data, dtype=np.uint8, count=key_length)
  44. if c_out_data:
  45. out_length = func(in_arr, None)
  46. out_arr = np.frombuffer(c_out_data, dtype=np.uint8, count=out_length)
  47. return func(in_arr, key_arr, out_arr)
  48. # just query the output length
  49. else:
  50. return func(in_arr, key_arr, None)
  51. return wrapper
  52. class LiteGlobal(object):
  53. """
  54. some global config in lite
  55. """
  56. _api = _GlobalAPI()._lib
  57. @staticmethod
  58. def register_decryption_and_key(decryption_name, decryption_func, key):
  59. c_name = c_char_p(decryption_name.encode("utf-8"))
  60. key_length = len(key)
  61. c_key = (c_uint8 * key_length)(*key)
  62. LiteGlobal._api.LITE_register_decryption_and_key(
  63. c_name, decryption_func, c_key, key_length
  64. )
  65. @staticmethod
  66. def update_decryption_key(decryption_name, key):
  67. c_name = c_char_p(decryption_name.encode("utf-8"))
  68. key_length = len(key)
  69. c_key = (c_uint8 * key_length)(*key)
  70. LiteGlobal._api.LITE_update_decryption_or_key(c_name, None, c_key, key_length)
  71. @staticmethod
  72. def set_loader_lib_path(path):
  73. c_path = c_char_p(path.encode("utf-8"))
  74. LiteGlobal._api.LITE_set_loader_lib_path(c_path)
  75. @staticmethod
  76. def set_persistent_cache(path, always_sync=False):
  77. c_path = c_char_p(path.encode("utf-8"))
  78. LiteGlobal._api.LITE_set_persistent_cache(c_path, always_sync)
  79. @staticmethod
  80. def set_tensorrt_cache(path):
  81. c_path = c_char_p(path.encode("utf-8"))
  82. LiteGlobal._api.LITE_set_tensorrt_cache(c_path)
  83. @staticmethod
  84. def dump_persistent_cache(path):
  85. c_path = c_char_p(path.encode("utf-8"))
  86. LiteGlobal._api.LITE_dump_persistent_cache(c_path)
  87. @staticmethod
  88. def dump_tensorrt_cache():
  89. LiteGlobal._api.LITE_dump_tensorrt_cache()
  90. @staticmethod
  91. def get_device_count(device_type):
  92. count = c_size_t()
  93. LiteGlobal._api.LITE_get_device_count(device_type, byref(count))
  94. return count.value
  95. @staticmethod
  96. def try_coalesce_all_free_memory():
  97. LiteGlobal._api.LITE_try_coalesce_all_free_memory()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。