You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 5.5 kB

(Lines 1–159)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import ctypes
  10. import glob
  11. import logging
  12. import os
  13. import sys
  14. from ctypes import *
  15. if sys.platform == "win32":
  16. lib_path = os.path.join(os.path.dirname(__file__), "../megengine/core/lib")
  17. dll_paths = list(filter(os.path.exists, [lib_path,]))
  18. assert len(dll_paths) > 0
  19. kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
  20. has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
  21. old_error_mode = kernel32.SetErrorMode(0x0001)
  22. kernel32.LoadLibraryW.restype = ctypes.c_void_p
  23. if has_load_library_attr:
  24. kernel32.AddDllDirectory.restype = ctypes.c_void_p
  25. kernel32.LoadLibraryExW.restype = ctypes.c_void_p
  26. for dll_path in dll_paths:
  27. if sys.version_info >= (3, 8):
  28. os.add_dll_directory(dll_path)
  29. elif has_load_library_attr:
  30. res = kernel32.AddDllDirectory(dll_path)
  31. if res is None:
  32. err = ctypes.WinError(ctypes.get_last_error())
  33. err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
  34. dll_path
  35. )
  36. raise err
  37. else:
  38. print("WARN: python or OS env have some issue, may load DLL failed!!!")
  39. import glob
  40. dlls = glob.glob(os.path.join(lib_path, "*.dll"))
  41. path_patched = False
  42. for dll in dlls:
  43. is_loaded = False
  44. if has_load_library_attr:
  45. res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
  46. last_error = ctypes.get_last_error()
  47. if res is None and last_error != 126:
  48. err = ctypes.WinError(last_error)
  49. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
  50. dll
  51. )
  52. err.strerror += " \nplease install VC runtime from: "
  53. err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
  54. raise err
  55. elif res is not None:
  56. is_loaded = True
  57. if not is_loaded:
  58. if not path_patched:
  59. os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
  60. path_patched = True
  61. res = kernel32.LoadLibraryW(dll)
  62. if res is None:
  63. err = ctypes.WinError(ctypes.get_last_error())
  64. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
  65. dll
  66. )
  67. err.strerror += " \nplease install VC runtime from: "
  68. err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
  69. raise err
  70. kernel32.SetErrorMode(old_error_mode)
  71. class _LiteCLib:
  72. def __init__(self):
  73. cwd = os.getcwd()
  74. package_dir = os.path.dirname(os.path.realpath(__file__))
  75. debug_path = os.getenv("LITE_LIB_PATH")
  76. os.chdir(package_dir)
  77. lite_libs = glob.glob("libs/liblite*")
  78. os.chdir(cwd)
  79. if debug_path is None:
  80. assert len(lite_libs) == 1
  81. self._lib = CDLL(os.path.join(package_dir, lite_libs[0]))
  82. else:
  83. self._lib = CDLL(debug_path)
  84. self._register_api(
  85. "LITE_get_version", [POINTER(c_int), POINTER(c_int), POINTER(c_int)]
  86. )
  87. self.lib.LITE_get_version.restype = None
  88. self._register_api("LITE_set_log_level", [c_int])
  89. self._register_api("LITE_get_log_level", [])
  90. self._register_api("LITE_get_last_error", [], False)
  91. self.lib.LITE_get_last_error.restype = c_char_p
  92. def _errcheck(self, result, func, args):
  93. if result:
  94. error = self.lib.LITE_get_last_error()
  95. msg = error.decode("utf-8")
  96. logging.error("{}".format(msg))
  97. raise RuntimeError("{}".format(msg))
  98. return result
  99. def _register_api(self, api_name, arg_types, error_check=True):
  100. func = getattr(self.lib, api_name)
  101. func.argtypes = arg_types
  102. func.restype = c_int
  103. if error_check:
  104. func.errcheck = self._errcheck
  105. @property
  106. def lib(self):
  107. return self._lib
  108. @property
  109. def version(self):
  110. major = c_int()
  111. minor = c_int()
  112. patch = c_int()
  113. self.lib.LITE_get_version(byref(major), byref(minor), byref(patch))
  114. return "{}.{}.{}".format(major.value, minor.value, patch.value)
  115. def set_log_level(self, level):
  116. self.lib.LITE_set_log_level(level)
  117. def get_log_level(self):
  118. return self.lib.LITE_get_log_level()
  119. _lib = _LiteCLib()
  120. version = _lib.version
  121. set_log_level = _lib.set_log_level
  122. get_log_level = _lib.get_log_level
  123. _Cnetwork = c_void_p
  124. _Ctensor = c_void_p
  125. class _LiteCObjMetaClass(type):
  126. """metaclass for lite object"""
  127. def __new__(cls, name, bases, attrs):
  128. for api in attrs["_api_"]:
  129. _lib._register_api(*api)
  130. del attrs["_api_"]
  131. attrs["_lib"] = _lib.lib
  132. return super().__new__(cls, name, bases, attrs)
  133. class _LiteCObjBase(metaclass=_LiteCObjMetaClass):
  134. _api_ = []

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。