You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import atexit
  10. import ctypes
  11. import os
  12. import platform
  13. import sys
  if sys.platform == "win32":
      # The native extension modules imported further down in this file are
      # loaded from the package; on Windows the DLLs shipped in
      # <package>/core/lib are made visible to the loader and preloaded here
      # before those imports run.
      lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
      dll_paths = [path for path in (lib_path,) if os.path.exists(path)]
      # Explicit check rather than ``assert``: asserts are stripped under
      # ``python -O``, which would turn a missing directory into an obscure
      # failure later when the extension modules are imported.
      if not dll_paths:
          raise RuntimeError(
              'MegEngine DLL directory "{}" does not exist.'.format(lib_path)
          )

      kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
      # AddDllDirectory is looked up dynamically because it may not be
      # exported; the code below also uses its presence as the signal that
      # LoadLibraryExW accepts the LOAD_LIBRARY_SEARCH_* flags.
      has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
      # 0x0001 is SEM_FAILCRITICALERRORS (suppress the system error dialog
      # while loading). The previous mode is restored in ``finally`` so that a
      # failed load does not leave the whole process in this mode.
      old_error_mode = kernel32.SetErrorMode(0x0001)
      try:
          # Handle-returning functions get a c_void_p restype so a NULL
          # (failure) result comes back as ``None``.
          kernel32.LoadLibraryW.restype = ctypes.c_void_p
          if has_load_library_attr:
              kernel32.AddDllDirectory.restype = ctypes.c_void_p
              kernel32.LoadLibraryExW.restype = ctypes.c_void_p

          # Register each DLL directory with the loader's search path.
          for dll_path in dll_paths:
              if sys.version_info >= (3, 8):
                  os.add_dll_directory(dll_path)
              elif has_load_library_attr:
                  res = kernel32.AddDllDirectory(dll_path)
                  if res is None:
                      err = ctypes.WinError(ctypes.get_last_error())
                      err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
                          dll_path
                      )
                      raise err
              else:
                  print("WARN: python or OS env have some issue, may load DLL failed!!!")

          import glob

          # Preload every DLL in the library directory.
          dlls = glob.glob(os.path.join(lib_path, "*.dll"))
          path_patched = False
          for dll in dlls:
              is_loaded = False
              if has_load_library_attr:
                  # 0x00001100 == LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
                  #             | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x0100).
                  res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                  last_error = ctypes.get_last_error()
                  # 126 is ERROR_MOD_NOT_FOUND: instead of failing, fall through
                  # to the PATH-based LoadLibraryW retry below.
                  if res is None and last_error != 126:
                      err = ctypes.WinError(last_error)
                      err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                          dll
                      )
                      raise err
                  elif res is not None:
                      is_loaded = True
              if not is_loaded:
                  # Fallback: prepend the DLL directories to PATH (only once)
                  # and retry with the plain LoadLibraryW search order.
                  if not path_patched:
                      os.environ["PATH"] = ";".join(
                          dll_paths + [os.environ.get("PATH", "")]
                      )
                      path_patched = True
                  res = kernel32.LoadLibraryW(dll)
                  if res is None:
                      err = ctypes.WinError(ctypes.get_last_error())
                      err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                          dll
                      )
                      raise err
      finally:
          # Restore the caller's error mode even when loading raised.
          kernel32.SetErrorMode(old_error_mode)
  66. from .core._imperative_rt.core2 import release_trace_apply_func, sync, full_sync
  67. from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
  68. from .device import *
  69. from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
  70. from .serialization import load, save
  71. from .tensor import Parameter, Tensor, tensor
  72. from .utils import comp_graph_tools as cgtools
  73. from .utils import persistent_cache
  74. from .version import __version__
  # Hand the native runtime the current interpreter and the path of
  # utils/_timed_func_fork_exec_entry.py. Judging by the names, this is used
  # when timed functions run in a fork/exec'd child (NOTE(review): confirm).
  _fork_exec_entry = os.path.join(
      os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"
  )
  _set_fork_exec_path_for_timed_func(sys.executable, _fork_exec_entry)

  # Create the server-backed persistent cache and call its ``reg()`` hook
  # (presumably installs it as the active cache -- verify in persistent_cache).
  _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
  _persistent_cache_impl_ins.reg()

  # atexit runs handlers in reverse registration order, so
  # release_trace_apply_func runs before sync at interpreter exit. Keep this
  # registration order.
  for _exit_hook in (sync, release_trace_apply_func):
      atexit.register(_exit_hook)

  # Drop the helper names so they do not linger in the package namespace.
  del (
      sync,
      release_trace_apply_func,
      _set_fork_exec_path_for_timed_func,
      _persistent_cache_impl_ins,
      _exit_hook,
      _fork_exec_entry,
  )
  87. # subpackages
  88. import megengine.autodiff
  89. import megengine.data
  90. import megengine.distributed
  91. import megengine.functional
  92. import megengine.hub
  93. import megengine.jit
  94. import megengine.module
  95. import megengine.optimizer
  96. import megengine.quantization
  97. import megengine.random
  98. import megengine.utils

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。