You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import sys
  11. import platform
  12. import ctypes
  13. import atexit
  if sys.platform == "win32":
      # Windows-only bootstrap: make the DLLs bundled under "core/lib" findable
      # and preload them, so that the native extension modules imported later
      # in this file can resolve their dependencies.
      import glob

      lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
      dll_paths = list(filter(os.path.exists, [lib_path]))
      # Explicit check instead of `assert`, which is stripped under `python -O`.
      if not dll_paths:
          raise ImportError(
              'DLL directory "{}" does not exist.'.format(lib_path)
          )

      kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
      # AddDllDirectory is not exported by kernel32 on some older systems.
      has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
      # 0x0001 is SEM_FAILCRITICALERRORS in the Win32 API; the previous mode is
      # saved so it can be put back once loading is finished.
      old_error_mode = kernel32.SetErrorMode(0x0001)

      try:
          # With restype c_void_p, ctypes returns None for a NULL handle,
          # which is what the `res is None` checks below rely on.
          kernel32.LoadLibraryW.restype = ctypes.c_void_p
          if has_load_library_attr:
              kernel32.AddDllDirectory.restype = ctypes.c_void_p
              kernel32.LoadLibraryExW.restype = ctypes.c_void_p

          # Step 1: register every DLL directory with the loader.
          for dll_path in dll_paths:
              if sys.version_info >= (3, 8):
                  os.add_dll_directory(dll_path)
              elif has_load_library_attr:
                  res = kernel32.AddDllDirectory(dll_path)
                  if res is None:
                      err = ctypes.WinError(ctypes.get_last_error())
                      err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
                          dll_path
                      )
                      raise err
              else:
                  print("WARN: python or OS env have some issue, may load DLL failed!!!")

          # Step 2: preload every bundled DLL.
          dlls = glob.glob(os.path.join(lib_path, "*.dll"))
          path_patched = False
          for dll in dlls:
              is_loaded = False
              if has_load_library_attr:
                  # 0x00001100 is LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000) |
                  # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100) in the Win32 API.
                  res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                  last_error = ctypes.get_last_error()
                  # 126 is ERROR_MOD_NOT_FOUND: tolerated here because the
                  # PATH-based fallback below gets a second attempt.
                  if res is None and last_error != 126:
                      err = ctypes.WinError(last_error)
                      err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                          dll
                      )
                      raise err
                  elif res is not None:
                      is_loaded = True
              if not is_loaded:
                  if not path_patched:
                      # Prepend the DLL directories to PATH once; tolerate an
                      # unset or empty PATH instead of raising KeyError.
                      os.environ["PATH"] = ";".join(
                          filter(None, dll_paths + [os.environ.get("PATH")])
                      )
                      path_patched = True
                  res = kernel32.LoadLibraryW(dll)
                  if res is None:
                      err = ctypes.WinError(ctypes.get_last_error())
                      err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                          dll
                      )
                      raise err
      finally:
          # Restore the saved error mode even when loading failed, so a failed
          # import does not leave the process-wide setting altered.
          kernel32.SetErrorMode(old_error_mode)
  66. from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
  67. from .core._imperative_rt.core2 import sync
  68. from .device import *
  69. from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
  70. from .serialization import load, save
  71. from .tensor import Parameter, Tensor, tensor
  72. from .version import __version__
  73. from .utils import persistent_cache, comp_graph_tools as cgtools
  # Hand the current interpreter and the helper script shipped under "utils"
  # to the native runtime (presumably used to fork-exec timed functions --
  # confirm in core._imperative_rt.utils).
  _timed_func_entry = os.path.join(
      os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"
  )
  _set_fork_exec_path_for_timed_func(sys.executable, _timed_func_entry)

  # Create the persistent cache implementation and call its reg() hook.
  _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
  _persistent_cache_impl_ins.reg()

  # Run sync() when the interpreter exits.
  atexit.register(sync)

  # Drop the bootstrap-only names from the package namespace.
  del (
      sync,
      _set_fork_exec_path_for_timed_func,
      _persistent_cache_impl_ins,
      _timed_func_entry,
  )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台