You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import sys
  11. import platform
  12. import ctypes
  13. if sys.platform == "win32":
  14. lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
  15. dll_paths = list(filter(os.path.exists, [lib_path,]))
  16. assert len(dll_paths) > 0
  17. kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
  18. has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
  19. old_error_mode = kernel32.SetErrorMode(0x0001)
  20. kernel32.LoadLibraryW.restype = ctypes.c_void_p
  21. if has_load_library_attr:
  22. kernel32.AddDllDirectory.restype = ctypes.c_void_p
  23. kernel32.LoadLibraryExW.restype = ctypes.c_void_p
  24. for dll_path in dll_paths:
  25. if sys.version_info >= (3, 8):
  26. os.add_dll_directory(dll_path)
  27. elif has_load_library_attr:
  28. res = kernel32.AddDllDirectory(dll_path)
  29. if res is None:
  30. err = ctypes.WinError(ctypes.get_last_error())
  31. err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
  32. dll_path
  33. )
  34. raise err
  35. else:
  36. print("WARN: python or OS env have some issue, may load DLL failed!!!")
  37. import glob
  38. dlls = glob.glob(os.path.join(lib_path, "*.dll"))
  39. path_patched = False
  40. for dll in dlls:
  41. is_loaded = False
  42. if has_load_library_attr:
  43. res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
  44. last_error = ctypes.get_last_error()
  45. if res is None and last_error != 126:
  46. err = ctypes.WinError(last_error)
  47. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
  48. dll
  49. )
  50. raise err
  51. elif res is not None:
  52. is_loaded = True
  53. if not is_loaded:
  54. if not path_patched:
  55. os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
  56. path_patched = True
  57. res = kernel32.LoadLibraryW(dll)
  58. if res is None:
  59. err = ctypes.WinError(ctypes.get_last_error())
  60. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
  61. dll
  62. )
  63. raise err
  64. kernel32.SetErrorMode(old_error_mode)
  65. from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
  66. from .device import *
  67. from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
  68. from .serialization import load, save
  69. from .tensor import Parameter, Tensor, tensor
  70. from .version import __version__
  71. from .utils import comp_graph_tools as cgtools
  72. _set_fork_exec_path_for_timed_func(
  73. sys.executable,
  74. os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
  75. )
  76. del _set_fork_exec_path_for_timed_func

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。