You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import atexit
  10. import ctypes
  11. import re
  12. import os
  13. import platform
  14. import sys
  15. if os.getenv("TERMUX_VERSION"):
  16. try:
  17. import cv2
  18. except Exception as exc:
  19. print("Run MegEngine python interface at Android/Termux env")
  20. print("!!!You need build opencv-python manually!!!, by run sh:")
  21. print(
  22. "https://github.com/MegEngine/MegEngine/blob/master/scripts/whl/android/android_opencv_python.sh"
  23. )
  24. raise exc
  25. if sys.platform == "win32":
  26. lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
  27. dll_paths = list(filter(os.path.exists, [lib_path,]))
  28. assert len(dll_paths) > 0
  29. kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
  30. has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
  31. old_error_mode = kernel32.SetErrorMode(0x0001)
  32. kernel32.LoadLibraryW.restype = ctypes.c_void_p
  33. if has_load_library_attr:
  34. kernel32.AddDllDirectory.restype = ctypes.c_void_p
  35. kernel32.LoadLibraryExW.restype = ctypes.c_void_p
  36. for dll_path in dll_paths:
  37. if sys.version_info >= (3, 8):
  38. os.add_dll_directory(dll_path)
  39. elif has_load_library_attr:
  40. res = kernel32.AddDllDirectory(dll_path)
  41. if res is None:
  42. err = ctypes.WinError(ctypes.get_last_error())
  43. err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
  44. dll_path
  45. )
  46. raise err
  47. else:
  48. print("WARN: python or OS env have some issue, may load DLL failed!!!")
  49. import glob
  50. dlls = glob.glob(os.path.join(lib_path, "*.dll"))
  51. path_patched = False
  52. for dll in dlls:
  53. is_loaded = False
  54. if has_load_library_attr:
  55. res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
  56. last_error = ctypes.get_last_error()
  57. if res is None and last_error != 126:
  58. err = ctypes.WinError(last_error)
  59. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
  60. dll
  61. )
  62. err.strerror += " \nplease install VC runtime from: "
  63. err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
  64. raise err
  65. elif res is not None:
  66. is_loaded = True
  67. if not is_loaded:
  68. if not path_patched:
  69. os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
  70. path_patched = True
  71. res = kernel32.LoadLibraryW(dll)
  72. if res is None:
  73. err = ctypes.WinError(ctypes.get_last_error())
  74. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
  75. dll
  76. )
  77. err.strerror += " \nplease install VC runtime from: "
  78. err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
  79. raise err
  80. kernel32.SetErrorMode(old_error_mode)
  81. from .core._imperative_rt.core2 import close as _close
  82. from .core._imperative_rt.core2 import full_sync as _full_sync
  83. from .core._imperative_rt.core2 import sync as _sync
  84. from .core._imperative_rt.common import (
  85. get_supported_sm_versions as _get_supported_sm_versions,
  86. )
  87. from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
  88. from .config import *
  89. from .device import *
  90. from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
  91. from .serialization import load, save
  92. from .tensor import Parameter, Tensor, tensor
  93. from .utils import comp_graph_tools as cgtools
  94. from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer
  95. from .version import __version__
  96. logger = get_logger(__name__)
  97. ngpus = get_device_count("gpu")
  98. supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
  99. for idx in range(ngpus):
  100. prop = get_cuda_device_property(idx)
  101. cur_sm = str(prop.major * 10 + prop.minor)
  102. if not cur_sm in supported_sm_versions:
  103. logger.warning(
  104. "{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
  105. prop.name,
  106. cur_sm,
  107. "capabilities" if len(supported_sm_versions) > 1 else "capability",
  108. " ".join(["sm_" + v for v in supported_sm_versions]),
  109. prop.name,
  110. )
  111. )
  112. _set_fork_exec_path_for_timed_func(
  113. sys.executable,
  114. os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
  115. )
  116. del _set_fork_exec_path_for_timed_func
  117. _exit_handlers = []
  118. def _run_exit_handlers():
  119. for handler in reversed(_exit_handlers):
  120. handler()
  121. _exit_handlers.clear()
  122. atexit.register(_run_exit_handlers)
  123. def _exit(code):
  124. _run_exit_handlers()
  125. sys.exit(code)
  126. def _atexit(handler):
  127. _exit_handlers.append(handler)
  128. _atexit(_close)
  129. _persistent_cache = _PersistentCacheOnServer()
  130. _persistent_cache.reg()
  131. _atexit(_persistent_cache.flush)
  132. # subpackages
  133. import megengine.amp
  134. import megengine.autodiff
  135. import megengine.data
  136. import megengine.distributed
  137. import megengine.dtr
  138. import megengine.functional
  139. import megengine.hub
  140. import megengine.jit
  141. import megengine.module
  142. import megengine.optimizer
  143. import megengine.quantization
  144. import megengine.random
  145. import megengine.utils
  146. import megengine.traced_module