|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- import ctypes
- import glob
- import logging
- import os
- import sys
- from ctypes import *
-
if sys.platform == "win32":
    # On Windows, make the DLLs bundled in the package's ``libs`` directory
    # loadable before the Lite library itself is opened with CDLL.
    lib_path = os.path.join(os.path.dirname(__file__), "libs")
    dll_paths = list(filter(os.path.exists, [lib_path]))
    assert len(dll_paths) > 0, "lite DLL directory not found: {}".format(lib_path)

    kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
    has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
    # SEM_FAILCRITICALERRORS (0x0001): report load failures to the caller
    # instead of showing a system error dialog. The previous mode is restored
    # in the ``finally`` below, even when a DLL fails to load.
    old_error_mode = kernel32.SetErrorMode(0x0001)
    try:
        # Handles are pointer-sized; c_void_p also maps a NULL return to None.
        kernel32.LoadLibraryW.restype = ctypes.c_void_p
        if has_load_library_attr:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p
            kernel32.LoadLibraryExW.restype = ctypes.c_void_p

        # Register the bundled directories in the DLL search path.
        for dll_path in dll_paths:
            if sys.version_info >= (3, 8):
                os.add_dll_directory(dll_path)
            elif has_load_library_attr:
                res = kernel32.AddDllDirectory(dll_path)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
                        dll_path
                    )
                    raise err
            else:
                print("WARN: python or OS env have some issue, may load DLL failed!!!")

        # Preload every bundled DLL.
        dlls = glob.glob(os.path.join(lib_path, "*.dll"))
        path_patched = False
        for dll in dlls:
            is_loaded = False
            if has_load_library_attr:
                # 0x1100 = LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
                #        | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100)
                res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                last_error = ctypes.get_last_error()
                # 126 is ERROR_MOD_NOT_FOUND: fall back to PATH-based loading
                # below; any other failure is fatal.
                if res is None and last_error != 126:
                    err = ctypes.WinError(last_error)
                    err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                        dll
                    )
                    raise err
                elif res is not None:
                    is_loaded = True
            if not is_loaded:
                # Fallback: prepend the bundled directories to PATH (once)
                # and load with the legacy search order.
                if not path_patched:
                    os.environ["PATH"] = ";".join(
                        dll_paths + [os.environ.get("PATH", "")]
                    )
                    path_patched = True
                res = kernel32.LoadLibraryW(dll)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                        dll
                    )
                    raise err
    finally:
        kernel32.SetErrorMode(old_error_mode)
-
-
class _LiteCLib:
    """Loader and thin ctypes wrapper for the MegEngine Lite C library.

    If the ``LITE_LIB_PATH`` environment variable is set, the library is loaded
    from that path. Otherwise exactly one ``libs/liblite*`` file must exist next
    to this module.
    """

    def __init__(self):
        debug_path = os.getenv("LITE_LIB_PATH")
        if debug_path is None:
            package_dir = os.path.dirname(os.path.realpath(__file__))
            # Glob with an absolute pattern instead of os.chdir() into the
            # package. chdir mutates process-wide state that other threads can
            # observe, and os.getcwd() raises if the current directory has
            # been removed.
            pattern = os.path.join(package_dir, "libs", "liblite*")
            lite_libs = glob.glob(pattern)
            # Explicit check rather than ``assert`` so it also runs under -O.
            if len(lite_libs) != 1:
                raise RuntimeError(
                    "expected exactly one lite library matching {}, found {}".format(
                        pattern, lite_libs
                    )
                )
            self._lib = CDLL(lite_libs[0])
        else:
            self._lib = CDLL(debug_path)
        self._register_api(
            "LITE_get_version", [POINTER(c_int), POINTER(c_int), POINTER(c_int)]
        )
        # The version is reported through the three out-parameters; no return value.
        self.lib.LITE_get_version.restype = None
        self._register_api("LITE_set_log_level", [c_int])
        self._register_api("LITE_get_log_level", [])
        # No error check here: _errcheck itself calls this function.
        self._register_api("LITE_get_last_error", [], False)
        self.lib.LITE_get_last_error.restype = c_char_p

    def _errcheck(self, result, func, args):
        """ctypes ``errcheck`` hook for C functions returning an ``int`` status.

        A non-zero ``result`` is treated as failure. The message from
        ``LITE_get_last_error`` is logged and raised as ``RuntimeError``.
        Otherwise ``result`` is returned unchanged.
        """
        if result:
            error = self.lib.LITE_get_last_error()
            if error is None:
                # c_char_p maps a NULL return to None; do not mask the real
                # failure with an AttributeError.
                msg = "lite C API call failed with code {}".format(result)
            else:
                # errors="replace": a non-UTF-8 message must not turn into a
                # UnicodeDecodeError that hides the original failure.
                msg = error.decode("utf-8", errors="replace")
            logging.error("%s", msg)
            raise RuntimeError(msg)
        return result

    def _register_api(self, api_name, arg_types, error_check=True):
        """Declare the signature of ``api_name`` on the loaded library.

        Args:
            api_name: exported C symbol name.
            arg_types: ctypes argument types, assigned to ``argtypes``.
            error_check: when true, attach :meth:`_errcheck` so a non-zero
                return raises ``RuntimeError``.

        The return type is set to ``c_int``; callers may override it afterwards.
        """
        func = getattr(self.lib, api_name)
        func.argtypes = arg_types
        func.restype = c_int
        if error_check:
            func.errcheck = self._errcheck

    @property
    def lib(self):
        """The underlying ``ctypes.CDLL`` handle."""
        return self._lib

    @property
    def version(self):
        """Library version as a ``"major.minor.patch"`` string."""
        major = c_int()
        minor = c_int()
        patch = c_int()
        self.lib.LITE_get_version(byref(major), byref(minor), byref(patch))
        return "{}.{}.{}".format(major.value, minor.value, patch.value)

    def set_log_level(self, level):
        """Forward ``level`` to ``LITE_set_log_level``."""
        self.lib.LITE_set_log_level(level)

    def get_log_level(self):
        """Return the value reported by ``LITE_get_log_level``."""
        return self.lib.LITE_get_log_level()
-
-
# Single module-wide loader for the Lite C library, created at import time.
_lib = _LiteCLib()

# Module-level shortcuts. ``version`` is read once here, so it holds the
# "major.minor.patch" string rather than a callable.
version = _lib.version
set_log_level, get_log_level = _lib.set_log_level, _lib.get_log_level

# Handle types for network / tensor objects; both are plain ``c_void_p`` on
# the Python side (presumably opaque pointers owned by the C library — confirm).
_Cnetwork = _Ctensor = c_void_p
-
-
class _LiteCObjMetaClass(type):
    """Metaclass for Python wrappers of Lite C objects.

    When a class is created, each entry of its class-body ``_api_`` sequence is
    unpacked into ``_lib._register_api(api_name, arg_types[, error_check])``.
    ``_api_`` is then removed from the class, and the loaded library handle is
    exposed as the class attribute ``_lib``.
    """

    def __new__(cls, name, bases, attrs):
        # ``attrs`` only holds names defined in this class body. The base's
        # ``_api_`` is removed here, so it is not inherited. A subclass that
        # declares no ``_api_`` of its own therefore has none; treat that as
        # "no APIs" instead of failing with KeyError.
        for api in attrs.pop("_api_", ()):
            _lib._register_api(*api)
        attrs["_lib"] = _lib.lib
        return super().__new__(cls, name, bases, attrs)
-
-
class _LiteCObjBase(metaclass=_LiteCObjMetaClass):
    """Common base for Python wrappers around Lite C objects.

    Subclasses declare the C functions they need in an ``_api_`` sequence. The
    metaclass registers those entries with the shared library loader, removes
    ``_api_`` from the class, and sets ``_lib`` to the loaded library.
    """

    # No C functions to register for the base class itself.
    _api_ = ()
|