# -*- coding: utf-8 -*-
"""Loader and thin ctypes wrapper for the native ``liblite*`` library.

Importing this module runs ``check_misc()``, loads the shared library and
publishes ``version``, ``set_log_level`` and ``get_log_level`` at module
level.
"""
import ctypes
import glob
import logging
import os
import sys

# NOTE(review): the star import is kept on purpose -- names such as ``c_int``
# and ``POINTER`` are used unqualified below, and other modules may rely on
# them being re-exported from here; confirm before narrowing it.
from ctypes import *

from ._env_initlization import check_misc

# check misc as soon as possible
check_misc()


class _LiteCLib:
    """Owns the loaded shared library handle and registers its C entry points."""

    def __init__(self):
        """Load the library and register the global (non-object) C APIs.

        If the ``LITE_LIB_PATH`` environment variable is set, that path is
        loaded; otherwise the single ``libs/liblite*`` file shipped next to
        this module is used.

        Raises:
            AssertionError: if the bundled library lookup does not find
                exactly one candidate.
            OSError: if ``CDLL`` cannot load the library.
        """
        debug_path = os.getenv("LITE_LIB_PATH")
        if debug_path is None:
            self._lib = CDLL(self._find_bundled_lib())
        else:
            self._lib = CDLL(debug_path)

        self._register_api(
            "LITE_get_version", [POINTER(c_int), POINTER(c_int), POINTER(c_int)]
        )
        # Overrides the c_int restype set by _register_api; with a None
        # result the error check never fires for this function.
        self.lib.LITE_get_version.restype = None
        self._register_api("LITE_set_log_level", [c_int])
        # NOTE(review): registered with the error check enabled, so any
        # non-zero return value raises RuntimeError -- confirm against the C
        # header that this function returns a status code, not the level.
        self._register_api("LITE_get_log_level", [])
        # Must not be error-checked: _errcheck itself calls this function.
        self._register_api("LITE_get_last_error", [], False)
        self.lib.LITE_get_last_error.restype = c_char_p

    @staticmethod
    def _find_bundled_lib():
        """Return the absolute path of the single bundled ``liblite*`` file."""
        package_dir = os.path.dirname(os.path.realpath(__file__))
        # Glob with an absolute pattern instead of chdir-ing into the package:
        # changing the working directory is process-wide and would not be
        # restored if glob raised.  The directory part is escaped so glob
        # metacharacters in the install path are matched literally.
        pattern = os.path.join(glob.escape(package_dir), "libs", "liblite*")
        lite_libs = glob.glob(pattern)
        if len(lite_libs) != 1:
            # Raised explicitly (rather than via ``assert``) so the check
            # still runs under ``python -O``; the exception type is unchanged.
            raise AssertionError(
                "expected exactly one library matching {}, found {}: {}".format(
                    pattern, len(lite_libs), lite_libs
                )
            )
        return lite_libs[0]

    def _errcheck(self, result, func, args):
        """ctypes ``errcheck`` hook: raise when a C call returns non-zero.

        Args:
            result: value returned by the C function.
            func: the ctypes function object that was called.
            args: the arguments the function was called with.

        Returns:
            ``result`` unchanged when it is falsy.

        Raises:
            RuntimeError: carrying the text from ``LITE_get_last_error`` when
                ``result`` is truthy.
        """
        if result:
            error = self.lib.LITE_get_last_error()
            if error is None:
                # A c_char_p restype yields None for a NULL pointer; still
                # report the failure instead of dying on ``None.decode``.
                msg = "lite C API call failed with code {}".format(result)
            else:
                # Never let undecodable bytes mask the underlying error.
                msg = error.decode("utf-8", errors="replace")
            logging.error(msg)
            raise RuntimeError(msg)
        return result

    def _register_api(self, api_name, arg_types, error_check=True):
        """Declare the ctypes signature of ``api_name`` on the loaded library.

        Args:
            api_name: name of the exported C symbol.
            arg_types: list of ctypes types assigned to ``argtypes``.
            error_check: when true, attach ``_errcheck`` so a non-zero return
                value raises ``RuntimeError``.
        """
        func = getattr(self.lib, api_name)
        func.argtypes = arg_types
        func.restype = c_int
        if error_check:
            func.errcheck = self._errcheck

    @property
    def lib(self):
        """The underlying ``CDLL`` handle."""
        return self._lib

    @property
    def version(self):
        """Library version as a ``"major.minor.patch"`` string."""
        major = c_int()
        minor = c_int()
        patch = c_int()
        self.lib.LITE_get_version(byref(major), byref(minor), byref(patch))
        return "{}.{}.{}".format(major.value, minor.value, patch.value)

    def set_log_level(self, level):
        """Pass ``level`` to ``LITE_set_log_level``."""
        self.lib.LITE_set_log_level(level)

    def get_log_level(self):
        """Return the value produced by ``LITE_get_log_level``."""
        return self.lib.LITE_get_log_level()


# The library is loaded once, at import time, and shared by everything below.
_lib = _LiteCLib()
version = _lib.version
set_log_level = _lib.set_log_level
get_log_level = _lib.get_log_level

_Cnetwork = c_void_p
_Ctensor = c_void_p


class _LiteCObjMetaClass(type):
    """metaclass for lite object"""

    def __new__(cls, name, bases, attrs):
        """Register the class's ``_api_`` entries, then build the class.

        Each entry of ``_api_`` is unpacked into ``_lib._register_api``.  The
        ``_api_`` attribute is removed from the class namespace and a ``_lib``
        attribute holding the loaded library handle is added instead.
        """
        # pop() with a default so a class body that declares no ``_api_``
        # simply registers nothing instead of raising KeyError.
        for api in attrs.pop("_api_", ()):
            _lib._register_api(*api)
        attrs["_lib"] = _lib.lib
        return super().__new__(cls, name, bases, attrs)


class _LiteCObjBase(metaclass=_LiteCObjMetaClass):
    """Base class whose creation goes through ``_LiteCObjMetaClass``."""

    _api_ = []