You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # -*- coding: utf-8 -*-
  2. import ctypes
  3. import glob
  4. import logging
  5. import os
  6. import sys
  7. from ctypes import *
  8. from ._env_initlization import check_misc
  9. # Run the environment check from ._env_initlization at import time, before the native library is loaded further down.
  10. check_misc()
  11. class _LiteCLib:
  12. def __init__(self):
  13. cwd = os.getcwd()
  14. package_dir = os.path.dirname(os.path.realpath(__file__))
  15. debug_path = os.getenv("LITE_LIB_PATH")
  16. os.chdir(package_dir)
  17. lite_libs = glob.glob("libs/liblite*")
  18. os.chdir(cwd)
  19. if debug_path is None:
  20. assert len(lite_libs) == 1
  21. self._lib = CDLL(os.path.join(package_dir, lite_libs[0]))
  22. else:
  23. self._lib = CDLL(debug_path)
  24. self._register_api(
  25. "LITE_get_version", [POINTER(c_int), POINTER(c_int), POINTER(c_int)]
  26. )
  27. self.lib.LITE_get_version.restype = None
  28. self._register_api("LITE_set_log_level", [c_int])
  29. self._register_api("LITE_get_log_level", [])
  30. self._register_api("LITE_get_last_error", [], False)
  31. self.lib.LITE_get_last_error.restype = c_char_p
  32. def _errcheck(self, result, func, args):
  33. if result:
  34. error = self.lib.LITE_get_last_error()
  35. msg = error.decode("utf-8")
  36. logging.error("{}".format(msg))
  37. raise RuntimeError("{}".format(msg))
  38. return result
  39. def _register_api(self, api_name, arg_types, error_check=True):
  40. func = getattr(self.lib, api_name)
  41. func.argtypes = arg_types
  42. func.restype = c_int
  43. if error_check:
  44. func.errcheck = self._errcheck
  45. @property
  46. def lib(self):
  47. return self._lib
  48. @property
  49. def version(self):
  50. major = c_int()
  51. minor = c_int()
  52. patch = c_int()
  53. self.lib.LITE_get_version(byref(major), byref(minor), byref(patch))
  54. return "{}.{}.{}".format(major.value, minor.value, patch.value)
  55. def set_log_level(self, level):
  56. self.lib.LITE_set_log_level(level)
  57. def get_log_level(self):
  58. return self.lib.LITE_get_log_level()
  59. _lib = _LiteCLib()
  60. version = _lib.version
  61. set_log_level = _lib.set_log_level
  62. get_log_level = _lib.get_log_level
  63. _Cnetwork = c_void_p
  64. _Ctensor = c_void_p
  65. class _LiteCObjMetaClass(type):
  66. """metaclass for lite object"""
  67. def __new__(cls, name, bases, attrs):
  68. for api in attrs["_api_"]:
  69. _lib._register_api(*api)
  70. del attrs["_api_"]
  71. attrs["_lib"] = _lib.lib
  72. return super().__new__(cls, name, bases, attrs)
  73. class _LiteCObjBase(metaclass=_LiteCObjMetaClass):
  74. _api_ = []