You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import atexit
  10. import ctypes
  11. import os
  12. import platform
  13. import sys
  # Windows only: register the bundled "core/lib" directory with the DLL search
  # path and pre-load every DLL found in it (presumably so the compiled
  # extension modules imported later in this file resolve their dependencies;
  # confirm against the packaging layout).
  if sys.platform == "win32":
      lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
      dll_paths = list(filter(os.path.exists, [lib_path]))
      assert len(dll_paths) > 0

      kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
      # AddDllDirectory is missing on old/unpatched Windows versions; its
      # presence also gates the use of LoadLibraryExW with search flags below.
      has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
      # 0x0001 is SEM_FAILCRITICALERRORS (Win32 SetErrorMode documentation).
      old_error_mode = kernel32.SetErrorMode(0x0001)
      try:
          # Return handles as c_void_p so that a NULL result is seen as None.
          kernel32.LoadLibraryW.restype = ctypes.c_void_p
          if has_load_library_attr:
              kernel32.AddDllDirectory.restype = ctypes.c_void_p
              kernel32.LoadLibraryExW.restype = ctypes.c_void_p

          for dll_path in dll_paths:
              if sys.version_info >= (3, 8):
                  os.add_dll_directory(dll_path)
              elif has_load_library_attr:
                  res = kernel32.AddDllDirectory(dll_path)
                  if res is None:
                      err = ctypes.WinError(ctypes.get_last_error())
                      err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
                          dll_path
                      )
                      raise err
              else:
                  print("WARN: python or OS env have some issue, may load DLL failed!!!")

          import glob

          dlls = glob.glob(os.path.join(lib_path, "*.dll"))
          path_patched = False
          for dll in dlls:
              is_loaded = False
              if has_load_library_attr:
                  # 0x00001100 = LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
                  #            | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100).
                  res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                  last_error = ctypes.get_last_error()
                  # 126 is ERROR_MOD_NOT_FOUND: not fatal here, the PATH-based
                  # fallback below gets a chance to load the DLL instead.
                  if res is None and last_error != 126:
                      err = ctypes.WinError(last_error)
                      err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                          dll
                      )
                      err.strerror += " \nplease install VC runtime from: "
                      err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
                      raise err
                  elif res is not None:
                      is_loaded = True
              if not is_loaded:
                  # Fallback: prepend the library directories to PATH (once)
                  # and retry with the plain LoadLibraryW search order.
                  if not path_patched:
                      os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
                      path_patched = True
                  res = kernel32.LoadLibraryW(dll)
                  if res is None:
                      err = ctypes.WinError(ctypes.get_last_error())
                      err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                          dll
                      )
                      err.strerror += " \nplease install VC runtime from: "
                      err.strerror += " \nhttps://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-160"
                      raise err
      finally:
          # Restore the process error mode even when a load error is raised
          # above; previously it was only restored on the success path.
          kernel32.SetErrorMode(old_error_mode)
  70. from .core._imperative_rt.core2 import close as _close
  71. from .core._imperative_rt.core2 import full_sync as _full_sync
  72. from .core._imperative_rt.core2 import sync as _sync
  73. from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
  74. from .device import *
  75. from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
  76. from .serialization import load, save
  77. from .tensor import Parameter, Tensor, tensor
  78. from .utils import comp_graph_tools as cgtools
  79. from .utils import persistent_cache
  80. from .version import __version__
  # Hand the runtime the current interpreter and the path of the bundled
  # utils/_timed_func_fork_exec_entry.py script (presumably used to fork/exec
  # timed functions; confirm against the _imperative_rt.utils binding).
  _timed_func_entry = os.path.join(
      os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"
  )
  _set_fork_exec_path_for_timed_func(sys.executable, _timed_func_entry)

  # Have core2's ``close`` invoked at interpreter exit.
  atexit.register(_close)

  # Remove the import-time-only names from the package namespace.
  del _set_fork_exec_path_for_timed_func, _timed_func_entry
  # Callables registered through ``_atexit``; they run in registration order.
  _exit_handlers = []


  def _run_exit_handlers():
      """Run every pending exit handler once, in registration order.

      Each handler is removed from ``_exit_handlers`` before it is called. If a
      handler raises, the exception propagates, but handlers that already ran
      are not run again by a later call (for example the ``atexit`` pass that
      follows a failed ``_exit``); handlers that were still pending stay
      registered for that later call.
      """
      while _exit_handlers:
          handler = _exit_handlers.pop(0)
          handler()


  # Also run the handlers on normal interpreter shutdown.
  atexit.register(_run_exit_handlers)


  def _exit(code):
      """Run the pending exit handlers, then call ``sys.exit(code)``."""
      _run_exit_handlers()
      sys.exit(code)


  def _atexit(handler):
      """Register ``handler`` (called with no arguments) to run at exit."""
      _exit_handlers.append(handler)
  98. # subpackages
  99. import megengine.amp
  100. import megengine.autodiff
  101. import megengine.data
  102. import megengine.distributed
  103. import megengine.dtr
  104. import megengine.functional
  105. import megengine.hub
  106. import megengine.jit
  107. import megengine.module
  108. import megengine.optimizer
  109. import megengine.quantization
  110. import megengine.random
  111. import megengine.utils
  112. import megengine.traced_module
  # Call get_manager() once at import time; the return value is discarded
  # (presumably this eagerly creates the persistent-cache manager; confirm
  # in utils/persistent_cache).
  persistent_cache.get_manager()