You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

custom_op_tools.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902
  1. import collections
  2. import ctypes
  3. import glob
  4. import os
  5. import re
  6. import subprocess
  7. import sys
  8. import time
  9. from typing import List, Optional, Union
  10. from ..core.ops.custom import get_custom_op_abi_tag, load
  11. from ..logger import get_logger
  12. def _get_win_folder_with_ctypes(csidl_name):
  13. csidl_const = {
  14. "CSIDL_APPDATA": 26,
  15. "CSIDL_COMMON_APPDATA": 35,
  16. "CSIDL_LOCAL_APPDATA": 28,
  17. }[csidl_name]
  18. buf = ctypes.create_unicode_buffer(1024)
  19. ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
  20. # Downgrade to short path name if have highbit chars. See
  21. # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
  22. has_high_char = False
  23. for c in buf:
  24. if ord(c) > 255:
  25. has_high_char = True
  26. break
  27. if has_high_char:
  28. buf2 = ctypes.create_unicode_buffer(1024)
  29. if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
  30. buf = buf2
  31. return buf.value
# Platform identifier (sys.platform), e.g. "win32", "linux" or "darwin".
system = sys.platform
if system == "win32":
    # Only defined on Windows; _get_user_cache_dir calls it solely on its
    # "win32" branch.
    _get_win_folder = _get_win_folder_with_ctypes
# Maps a platform tag to a vcvars architecture argument.
# NOTE(review): not referenced in the code visible here — confirm it is used elsewhere.
PLAT_TO_VCVARS = {
    "win-amd64": "x86_amd64",
}
logger = get_logger()

# Names of the environment variables read by this module:
#   MGE_CUSTOM_OP_DIR - root under which per-op build directories are created
#   CUDA_ROOT_DIR     - CUDA toolkit root (takes precedence over auto-detection)
#   CUDNN_ROOT_DIR    - cuDNN root (the only way cuDNN is located)
ev_custom_op_root_dir = "MGE_CUSTOM_OP_DIR"
ev_cuda_root_dir = "CUDA_ROOT_DIR"
ev_cudnn_root_dir = "CUDNN_ROOT_DIR"

# operating system
IS_WINDOWS = system == "win32"
IS_LINUX = system == "linux"
IS_MACOS = system == "darwin"

# MegEngine package directory (two levels above this file) and the headers /
# shared library that custom ops are compiled and linked against.
MGE_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MGE_INC_PATH = os.path.join(MGE_PATH, "core", "include")
MGE_LIB_PATH = os.path.join(MGE_PATH, "core", "lib")
# Emitted as -D_GLIBCXX_USE_CXX11_ABI=<value> on non-Windows builds; build()
# overwrites it with the requested or detected ABI tag before compiling.
MGE_ABI_VER = 0

# Minimum compiler versions used by the ABI compatibility check: GCC on Linux,
# clang-cl on other non-macOS platforms (macOS is not checked).
MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_CLANG_CL_VERSION = (12, 0, 1)

# compile flags
# MSVC-style flags passed to clang-cl on Windows and forwarded to nvcc's host
# compiler with -Xcompiler.
COMMON_MSVC_FLAGS = [
    "/MD",
    "/wd4002",
    "/wd4819",
    "/EHsc",
]
# cudafe diagnostics suppressed (via --diag_suppress) for CUDA builds on Windows.
MSVC_IGNORE_CUDAFE_WARNINGS = [
    "field_without_dll_interface",
]
# Extra nvcc flags applied to every CUDA compilation (currently none).
COMMON_NVCC_FLAGS = []
  65. # Finds the CUDA install path
  66. def _find_cuda_root_dir() -> Optional[str]:
  67. cuda_root_dir = os.environ.get(ev_cuda_root_dir)
  68. if cuda_root_dir is None:
  69. try:
  70. which = "where" if IS_WINDOWS else "which"
  71. with open(os.devnull, "w") as devnull:
  72. nvcc = (
  73. subprocess.check_output([which, "nvcc"], stderr=devnull)
  74. .decode()
  75. .rstrip("\r\n")
  76. )
  77. cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
  78. except Exception:
  79. if IS_WINDOWS:
  80. cuda_root_dir = os.environ.get("CUDA_PATH", None)
  81. if cuda_root_dir == None:
  82. cuda_root_dirs = glob.glob(
  83. "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
  84. )
  85. if len(cuda_root_dirs) == 0:
  86. cuda_root_dir = ""
  87. else:
  88. cuda_root_dir = cuda_root_dirs[0]
  89. else:
  90. cuda_root_dir = "/usr/local/cuda"
  91. if not os.path.exists(cuda_root_dir):
  92. cuda_root_dir = None
  93. return cuda_root_dir
  94. def _find_cudnn_root_dir() -> Optional[str]:
  95. cudnn_root_dir = os.environ.get(ev_cudnn_root_dir)
  96. return cudnn_root_dir
# Resolved once, at import time. Either may be None; build() prints a hint
# when a root it needs is missing.
CUDA_ROOT_DIR = _find_cuda_root_dir()
CUDNN_ROOT_DIR = _find_cudnn_root_dir()
  99. #####################################################################
  100. # Phase 1
  101. #####################################################################
  102. def _is_cuda_file(path: str) -> bool:
  103. valid_ext = [".cu", ".cuh"]
  104. return os.path.splitext(path)[1] in valid_ext
  105. # Return full path to the user-specific cache dir for this application.
  106. # Typical user cache directories are:
  107. # Mac OS X: ~/Library/Caches/<AppName>
  108. # Unix: ~/.cache/<AppName> (XDG default)
  109. # Windows: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
  110. def _get_user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
  111. if system == "win32":
  112. appauthor = appname if appauthor is None else appauthor
  113. path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
  114. if appname:
  115. if appauthor is not False:
  116. path = os.path.join(path, appauthor)
  117. else:
  118. path = os.path.join(path, appname)
  119. if opinion:
  120. path = os.path.join(path, "Cache")
  121. elif system == "darwin":
  122. path = os.path.expanduser("~/Library/Caches")
  123. if appname:
  124. path = os.path.join(path, appname)
  125. else:
  126. path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
  127. if appname:
  128. path = os.path.join(path, appname)
  129. if appname and version:
  130. path = os.path.join(path, version)
  131. return path
  132. # Returns the path to the root folder under which custom op will built.
  133. def _get_default_build_root() -> str:
  134. return os.path.realpath(_get_user_cache_dir(appname="mge_custom_op"))
  135. def _get_build_dir(name: str) -> str:
  136. custom_op_root_dir = os.environ.get(ev_custom_op_root_dir)
  137. if custom_op_root_dir is None:
  138. custom_op_root_dir = _get_default_build_root()
  139. build_dir = os.path.join(custom_op_root_dir, name)
  140. return build_dir
  141. #####################################################################
  142. # Phase 2
  143. #####################################################################
  144. def update_hash(seed, value):
  145. # using boost::hash_combine
  146. # https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
  147. return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))
  148. def hash_source_files(hash_value, source_files):
  149. for filename in source_files:
  150. with open(filename) as file:
  151. hash_value = update_hash(hash_value, file.read())
  152. return hash_value
  153. def hash_build_args(hash_value, build_args):
  154. for group in build_args:
  155. for arg in group:
  156. hash_value = update_hash(hash_value, arg)
  157. return hash_value
  158. Entry = collections.namedtuple("Entry", "version, hash")
  159. class Versioner(object):
  160. def __init__(self):
  161. self.entries = {}
  162. def get_version(self, name):
  163. entry = self.entries.get(name)
  164. return None if entry is None else entry.version
  165. def bump_version_if_changed(
  166. self, name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag
  167. ):
  168. hash_value = 0
  169. hash_value = hash_source_files(hash_value, sources)
  170. hash_value = hash_build_args(hash_value, build_args)
  171. hash_value = update_hash(hash_value, build_dir)
  172. hash_value = update_hash(hash_value, with_cuda)
  173. hash_value = update_hash(hash_value, with_cudnn)
  174. hash_value = update_hash(hash_value, abi_tag)
  175. entry = self.entries.get(name)
  176. if entry is None:
  177. self.entries[name] = entry = Entry(0, hash_value)
  178. elif hash_value != entry.hash:
  179. self.entries[name] = entry = Entry(entry.version + 1, hash_value)
  180. return entry.version
  181. custom_op_versioner = Versioner()
  182. def version_check(
  183. name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag,
  184. ):
  185. old_version = custom_op_versioner.get_version(name)
  186. version = custom_op_versioner.bump_version_if_changed(
  187. name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag,
  188. )
  189. return version, old_version
  190. #####################################################################
  191. # Phase 3
  192. #####################################################################
  193. def _check_ninja_availability():
  194. try:
  195. subprocess.check_output("ninja --version".split())
  196. except Exception:
  197. raise RuntimeError(
  198. "Ninja is required to build custom op, please install ninja and update your PATH"
  199. )
  200. def _mge_is_built_from_src():
  201. file_path = os.path.abspath(__file__)
  202. if "site-packages" in file_path:
  203. return False
  204. else:
  205. return True
  206. def _accepted_compilers_for_platform():
  207. if IS_WINDOWS:
  208. return ["clang-cl"]
  209. if IS_MACOS:
  210. return ["clang++", "clang"]
  211. if IS_LINUX:
  212. return ["g++", "gcc", "gnu-c++", "gnu-cc"]
  213. # Verifies that the compiler is the expected one for the current platform.
  214. def _check_compiler_existed_for_platform(compiler: str) -> bool:
  215. # there is no suitable cmd like `which` on windows, so we assume the compiler is always true on windows
  216. if IS_WINDOWS:
  217. try:
  218. version_string = subprocess.check_output(
  219. ["clang-cl", "--version"], stderr=subprocess.STDOUT
  220. ).decode()
  221. return True
  222. except Exception:
  223. return False
  224. # use os.path.realpath to resolve any symlinks, in particular from "c++" to e.g. "g++".
  225. which = subprocess.check_output(["which", compiler], stderr=subprocess.STDOUT)
  226. compiler_path = os.path.realpath(which.decode().strip())
  227. if any(name in compiler_path for name in _accepted_compilers_for_platform()):
  228. return True
  229. version_string = subprocess.check_output(
  230. [compiler, "-v"], stderr=subprocess.STDOUT
  231. ).decode()
  232. if sys.platform.startswith("linux"):
  233. pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
  234. results = re.findall(pattern, version_string)
  235. if len(results) != 1:
  236. return False
  237. compiler_path = os.path.realpath(results[0].strip())
  238. return any(name in compiler_path for name in _accepted_compilers_for_platform())
  239. if sys.platform.startswith("darwin"):
  240. return version_string.startswith("Apple clang")
  241. return False
  242. # Verifies that the given compiler is ABI-compatible with MegEngine.
  243. def _check_compiler_abi_compatibility(compiler: str):
  244. # we think if the megengine is built from source, the user will use the same compiler to compile the custom op
  245. if _mge_is_built_from_src() or os.environ.get("MGE_CHECK_ABI", "1") == "0":
  246. return True
  247. # [TODO] There is no particular minimum version we need for clang, so we"re good here.
  248. if sys.platform.startswith("darwin"):
  249. return True
  250. try:
  251. if sys.platform.startswith("linux"):
  252. minimum_required_version = MINIMUM_GCC_VERSION
  253. versionstr = subprocess.check_output(
  254. [compiler, "-dumpfullversion", "-dumpversion"]
  255. )
  256. version = versionstr.decode().strip().split(".")
  257. else:
  258. minimum_required_version = MINIMUM_CLANG_CL_VERSION
  259. compiler_info = subprocess.check_output(
  260. [compiler, "--version"], stderr=subprocess.STDOUT
  261. )
  262. match = re.search(r"(\d+)\.(\d+)\.(\d+)", compiler_info.decode().strip())
  263. version = (0, 0, 0) if match is None else match.groups()
  264. except Exception:
  265. _, error, _ = sys.exc_info()
  266. logger.warning(
  267. "Error checking compiler version for {}: {}".format(compiler, error)
  268. )
  269. return False
  270. if tuple(map(int, version)) >= minimum_required_version:
  271. return True
  272. return False
  273. def _check_compiler_comatibility():
  274. # we use clang-cl on windows, refer: https://clang.llvm.org/docs/UsersManual.html#clang-cl
  275. compiler = (
  276. os.environ.get("CXX", "clang-cl")
  277. if IS_WINDOWS
  278. else os.environ.get("CXX", "c++")
  279. )
  280. existed = _check_compiler_existed_for_platform(compiler)
  281. if existed == False:
  282. log_str = (
  283. "Cannot find compiler which is compatible with the compiler "
  284. "MegEngine was built with for this platform, which is {mge_compiler} on "
  285. "{platform}. Please use {mge_compiler} to to compile your extension. "
  286. "Alternatively, you may compile MegEngine from source using "
  287. "{user_compiler}, and then you can also use {user_compiler} to compile "
  288. "your extension."
  289. ).format(
  290. user_compiler=compiler,
  291. mge_compiler=_accepted_compilers_for_platform()[0],
  292. platform=sys.platform,
  293. )
  294. logger.warning(log_str)
  295. return False
  296. compatible = _check_compiler_abi_compatibility(compiler)
  297. if compatible == False:
  298. log_str = (
  299. "Your compiler version may be ABI-incompatible with MegEngine! "
  300. "Please use a compiler that is ABI-compatible with GCC 5.0 on Linux "
  301. "and LLVM/Clang 12.0 on Windows ."
  302. )
  303. logger.warning(log_str)
  304. return True
  305. #####################################################################
  306. # Phase 4
  307. #####################################################################
  308. # Quote command-line arguments for DOS/Windows conventions.
  309. def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
  310. # Cover None-type
  311. if not args:
  312. return []
  313. return ['"{}"'.format(arg) if " " in arg else arg for arg in args]
  314. # Now we need user to specify the arch of GPU
  315. def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
  316. return []
  317. def _setup_sys_includes(with_cuda: bool, with_cudnn: bool):
  318. includes = [os.path.join(MGE_INC_PATH)]
  319. if with_cuda:
  320. includes.append(os.path.join(CUDA_ROOT_DIR, "include"))
  321. if with_cudnn:
  322. includes.append(os.path.join(CUDNN_ROOT_DIR, "include"))
  323. return includes
  324. def _setup_includes(extra_include_paths: List[str], with_cuda: bool, with_cudnn: bool):
  325. user_includes = [os.path.abspath(path) for path in extra_include_paths]
  326. system_includes = _setup_sys_includes(with_cuda, with_cudnn)
  327. if IS_WINDOWS:
  328. user_includes += system_includes
  329. system_includes.clear()
  330. return user_includes, system_includes
  331. def _setup_common_cflags(user_includes: List[str], system_includes: List[str]):
  332. common_cflags = []
  333. common_cflags += ["-I{}".format(include) for include in user_includes]
  334. common_cflags += ["-isystem {}".format(include) for include in system_includes]
  335. if not IS_WINDOWS:
  336. common_cflags += ["-D_GLIBCXX_USE_CXX11_ABI={}".format(MGE_ABI_VER)]
  337. return common_cflags
  338. def _setup_cuda_cflags(cflags: List[str], extra_cuda_cflags: List[str]):
  339. cuda_flags = cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
  340. if IS_WINDOWS:
  341. for flag in COMMON_MSVC_FLAGS:
  342. cuda_flags = ["-Xcompiler", flag] + cuda_flags
  343. for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
  344. cuda_flags = ["-Xcudafe", "--diag_suppress=" + ignore_warning] + cuda_flags
  345. cuda_flags = _nt_quote_args(cuda_flags)
  346. cuda_flags += _nt_quote_args(extra_cuda_cflags)
  347. else:
  348. cuda_flags += ["--compiler-options", '"-fPIC"']
  349. cuda_flags += extra_cuda_cflags
  350. if not any(flag.startswith("-std=") for flag in cuda_flags):
  351. cuda_flags.append("-std=c++14")
  352. if os.getenv("CC") is not None:
  353. cuda_flags = ["-ccbin", os.getenv("CC")] + cuda_flags
  354. return cuda_flags
  355. def _setup_ldflags(
  356. extra_ldflags: List[str], with_cuda: bool, with_cudnn: bool
  357. ) -> List[str]:
  358. ldflags = extra_ldflags
  359. if IS_WINDOWS:
  360. ldflags.append(os.path.join(MGE_LIB_PATH, "megengine_shared.lib"))
  361. if with_cuda:
  362. ldflags.append(os.path.join(CUDA_ROOT_DIR, "lib", "x64", "cudart.lib"))
  363. if with_cudnn:
  364. ldflags.append(os.path.join(CUDNN_ROOT_DIR, "lib", "x64", "cudnn.lib"))
  365. else:
  366. ldflags.append("-lmegengine_shared -L{}".format(MGE_LIB_PATH))
  367. ldflags.append("-Wl,-rpath,{}".format(MGE_LIB_PATH))
  368. if with_cuda:
  369. ldflags.append("-lcudart")
  370. ldflags.append("-L{}".format(os.path.join(CUDA_ROOT_DIR, "lib64")))
  371. ldflags.append("-Wl,-rpath,{}".format(os.path.join(CUDA_ROOT_DIR, "lib64")))
  372. if with_cudnn:
  373. ldflags.append("-L{}".format(os.path.join(CUDNN_ROOT_DIR, "lib64")))
  374. ldflags.append(
  375. "-Wl,-rpath,{}".format(os.path.join(CUDNN_ROOT_DIR, "lib64"))
  376. )
  377. return ldflags
  378. def _add_shared_flag(ldflags: List[str]):
  379. ldflags += ["/LD" if IS_WINDOWS else "-shared"]
  380. return ldflags
  381. #####################################################################
  382. # Phase 5
  383. #####################################################################
  384. def _obj_file_path(src_file_path: str):
  385. file_name = os.path.splitext(os.path.basename(src_file_path))[0]
  386. if _is_cuda_file(src_file_path):
  387. target = "{}.cuda.o".format(file_name)
  388. else:
  389. target = "{}.o".format(file_name)
  390. return target
  391. def _dump_ninja_file(
  392. path,
  393. cflags,
  394. post_cflags,
  395. cuda_cflags,
  396. cuda_post_cflags,
  397. sources,
  398. objects,
  399. ldflags,
  400. library_target,
  401. with_cuda,
  402. ):
  403. def sanitize_flags(flags):
  404. return [] if flags is None else [flag.strip() for flag in flags]
  405. cflags = sanitize_flags(cflags)
  406. post_cflags = sanitize_flags(post_cflags)
  407. cuda_cflags = sanitize_flags(cuda_cflags)
  408. cuda_post_cflags = sanitize_flags(cuda_post_cflags)
  409. ldflags = sanitize_flags(ldflags)
  410. assert len(sources) == len(objects)
  411. assert len(sources) > 0
  412. if IS_WINDOWS:
  413. compiler = os.environ.get("CXX", "clang-cl")
  414. else:
  415. compiler = os.environ.get("CXX", "c++")
  416. # Version 1.3 is required for the `deps` directive.
  417. config = ["ninja_required_version = 1.3"]
  418. config.append("cxx = {}".format(compiler))
  419. if with_cuda:
  420. nvcc = os.path.join(CUDA_ROOT_DIR, "bin", "nvcc")
  421. config.append("nvcc = {}".format(nvcc))
  422. flags = ["cflags = {}".format(" ".join(cflags))]
  423. flags.append("post_cflags = {}".format(" ".join(post_cflags)))
  424. if with_cuda:
  425. flags.append("cuda_cflags = {}".format(" ".join(cuda_cflags)))
  426. flags.append("cuda_post_cflags = {}".format(" ".join(cuda_post_cflags)))
  427. flags.append("ldflags = {}".format(" ".join(ldflags)))
  428. # Turn into absolute paths so we can emit them into the ninja build
  429. # file wherever it is.
  430. sources = [os.path.abspath(file) for file in sources]
  431. # See https://ninja-build.org/build.ninja.html for reference.
  432. compile_rule = ["rule compile"]
  433. if IS_WINDOWS:
  434. compile_rule.append(
  435. " command = clang-cl /showIncludes $cflags -c $in /Fo$out $post_cflags"
  436. )
  437. compile_rule.append(" deps = msvc")
  438. else:
  439. compile_rule.append(
  440. " command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags"
  441. )
  442. compile_rule.append(" depfile = $out.d")
  443. compile_rule.append(" deps = gcc")
  444. if with_cuda:
  445. cuda_compile_rule = ["rule cuda_compile"]
  446. nvcc_gendeps = ""
  447. cuda_compile_rule.append(
  448. " command = $nvcc {} $cuda_cflags -c $in -o $out $cuda_post_cflags".format(
  449. nvcc_gendeps
  450. )
  451. )
  452. # Emit one build rule per source to enable incremental build.
  453. build = []
  454. for source_file, object_file in zip(sources, objects):
  455. is_cuda_source = _is_cuda_file(source_file) and with_cuda
  456. rule = "cuda_compile" if is_cuda_source else "compile"
  457. if IS_WINDOWS:
  458. source_file = source_file.replace(":", "$:")
  459. object_file = object_file.replace(":", "$:")
  460. source_file = source_file.replace(" ", "$ ")
  461. object_file = object_file.replace(" ", "$ ")
  462. build.append("build {}: {} {}".format(object_file, rule, source_file))
  463. if library_target is not None:
  464. link_rule = ["rule link"]
  465. if IS_WINDOWS:
  466. link_rule.append(" command = clang-cl $in /nologo $ldflags /out:$out")
  467. else:
  468. link_rule.append(" command = $cxx $in $ldflags -o $out")
  469. link = ["build {}: link {}".format(library_target, " ".join(objects))]
  470. default = ["default {}".format(library_target)]
  471. else:
  472. link_rule, link, default = [], [], []
  473. # 'Blocks' should be separated by newlines, for visual benefit.
  474. blocks = [config, flags, compile_rule]
  475. if with_cuda:
  476. blocks.append(cuda_compile_rule)
  477. blocks += [link_rule, build, link, default]
  478. with open(path, "w") as build_file:
  479. for block in blocks:
  480. lines = "\n".join(block)
  481. build_file.write("{}\n\n".format(lines))
  482. class FileBaton:
  483. def __init__(self, lock_file_path, wait_seconds=0.1):
  484. self.lock_file_path = lock_file_path
  485. self.wait_seconds = wait_seconds
  486. self.fd = None
  487. def try_acquire(self):
  488. try:
  489. self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
  490. return True
  491. except FileExistsError:
  492. return False
  493. def wait(self):
  494. while os.path.exists(self.lock_file_path):
  495. time.sleep(self.wait_seconds)
  496. def release(self):
  497. if self.fd is not None:
  498. os.close(self.fd)
  499. os.remove(self.lock_file_path)
  500. #####################################################################
  501. # Phase 6
  502. #####################################################################
  503. def _build_with_ninja(build_dir: str, verbose: bool, error_prefix: str):
  504. command = ["ninja", "-v"]
  505. env = os.environ.copy()
  506. try:
  507. sys.stdout.flush()
  508. sys.stderr.flush()
  509. stdout_fileno = 1
  510. subprocess.run(
  511. command,
  512. stdout=stdout_fileno if verbose else subprocess.PIPE,
  513. stderr=subprocess.STDOUT,
  514. cwd=build_dir,
  515. check=True,
  516. env=env,
  517. )
  518. except subprocess.CalledProcessError as e:
  519. with open(os.path.join(build_dir, "build.ninja")) as f:
  520. lines = f.readlines()
  521. print(lines)
  522. _, error, _ = sys.exc_info()
  523. message = error_prefix
  524. if hasattr(error, "output") and error.output:
  525. message += ": {}".format(error.output.decode())
  526. raise RuntimeError(message) from e
  527. def build(
  528. name: str,
  529. sources: Union[str, List[str]],
  530. extra_cflags: Union[str, List[str]] = [],
  531. extra_cuda_cflags: Union[str, List[str]] = [],
  532. extra_ldflags: Union[str, List[str]] = [],
  533. extra_include_paths: Union[str, List[str]] = [],
  534. with_cuda: Optional[bool] = None,
  535. build_dir: Optional[bool] = None,
  536. verbose: bool = False,
  537. abi_tag: Optional[int] = None,
  538. ) -> str:
  539. r"""Build a Custom Op with ninja in the way of just-in-time (JIT).
  540. To build the custom op, a Ninja build file is emitted, which is used to
  541. compile the given sources into a dynamic library.
  542. By default, the directory to which the build file is emitted and the
  543. resulting library compiled to is ``<tmp>/mge_custom_op/<name>``, where
  544. ``<tmp>`` is the temporary folder on the current platform and ``<name>``
  545. the name of the custom op. This location can be overridden in two ways.
  546. First, if the ``MGE_CUSTOM_OP_DIR`` environment variable is set, it
  547. replaces ``<tmp>/mge_custom_op`` and all custom op will be compiled
  548. into subfolders of this directory. Second, if the ``build_dir``
  549. argument to this function is supplied, it overrides the entire path, i.e.
  550. the library will be compiled into that folder directly.
  551. To compile the sources, the default system compiler (``c++``) is used,
  552. which can be overridden by setting the ``CXX`` environment variable. To pass
  553. additional arguments to the compilation process, ``extra_cflags`` or
  554. ``extra_ldflags`` can be provided. For example, to compile your custom op
  555. with optimizations, pass ``extra_cflags=['-O3']``. You can also use
  556. ``extra_cflags`` to pass further include directories.
  557. CUDA support with mixed compilation is provided. Simply pass CUDA source
  558. files (``.cu`` or ``.cuh``) along with other sources. Such files will be
  559. detected and compiled with nvcc rather than the C++ compiler. This includes
  560. passing the CUDA lib64 directory as a library directory, and linking
  561. ``cudart``. You can pass additional flags to nvcc via
  562. ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
  563. heuristics for finding the CUDA install directory are used, which usually
  564. work fine. If not, setting the ``CUDA_ROOT_DIR`` environment variable is the
  565. safest option. If you use CUDNN, please also setting the ``CUDNN_ROOT_DIR``
  566. environment variable.
  567. Args:
  568. name: The name of the custom op to build.
  569. sources: A list of relative or absolute paths to C++ source files.
  570. extra_cflags: optional list of compiler flags to forward to the build.
  571. extra_cuda_cflags: optional list of compiler flags to forward to nvcc
  572. when building CUDA sources.
  573. extra_ldflags: optional list of linker flags to forward to the build.
  574. extra_include_paths: optional list of include directories to forward
  575. to the build.
  576. with_cuda: Determines whether CUDA headers and libraries are added to
  577. the build. If set to ``None`` (default), this value is
  578. automatically determined based on the existence of ``.cu`` or
  579. ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers
  580. and libraries to be included.
  581. build_dir: optional path to use as build workspace.
  582. verbose: If ``True``, turns on verbose logging of load steps.
  583. abi_tag: Determines the value of MACRO ``_GLIBCXX_USE_CXX11_ABI``
  584. in gcc compiler, should be ``0`` or ``1``.
  585. Returns:
  586. the compiled dynamic library path
  587. """
  588. # phase 1: prepare config
  589. if abi_tag == None:
  590. abi_tag = get_custom_op_abi_tag()
  591. global MGE_ABI_VER
  592. MGE_ABI_VER = abi_tag
  593. def strlist(args, name):
  594. assert isinstance(args, str) or isinstance(
  595. args, list
  596. ), "{} must be str or list[str]".format(name)
  597. if isinstance(args, str):
  598. return [args]
  599. for arg in args:
  600. assert isinstance(arg, str)
  601. args = [arg.strip() for arg in args]
  602. return args
  603. sources = strlist(sources, "sources")
  604. extra_cflags = strlist(extra_cflags, "extra_cflags")
  605. extra_cuda_cflags = strlist(extra_cuda_cflags, "extra_cuda_cflags")
  606. extra_ldflags = strlist(extra_ldflags, "extra_ldflags")
  607. extra_include_paths = strlist(extra_include_paths, "extra_include_paths")
  608. with_cuda = any(map(_is_cuda_file, sources)) if with_cuda is None else with_cuda
  609. with_cudnn = any(["cudnn" in f for f in extra_ldflags])
  610. if CUDA_ROOT_DIR == None and with_cuda:
  611. print(
  612. "No CUDA runtime is found, using {}=/path/to/your/cuda_root_dir".format(
  613. ev_cuda_root_dir
  614. )
  615. )
  616. if CUDNN_ROOT_DIR == None and with_cudnn:
  617. print(
  618. "Cannot find the root directory of cudnn, using {}=/path/to/your/cudnn_root_dir".format(
  619. ev_cudnn_root_dir
  620. )
  621. )
  622. build_dir = os.path.abspath(
  623. _get_build_dir(name) if build_dir is None else build_dir
  624. )
  625. if not os.path.exists(build_dir):
  626. os.makedirs(build_dir, exist_ok=True)
  627. if verbose:
  628. print("Using {} to build megengine custom op".format(build_dir))
  629. # phase 2: version check
  630. version, old_version = version_check(
  631. name,
  632. sources,
  633. [extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
  634. build_dir,
  635. with_cuda,
  636. with_cudnn,
  637. abi_tag,
  638. )
  639. if verbose:
  640. if version != old_version and old_version != None:
  641. print(
  642. "Input conditions of custom op {} have changed, bumping to version {}".format(
  643. name, version
  644. )
  645. )
  646. print("Building custom op {} with version {}".format(name, version))
  647. if version == old_version:
  648. if verbose:
  649. print(
  650. "No modifications detected for {}, skipping build step...".format(name)
  651. )
  652. return
  653. name = "{}_v{}".format(name, version)
  654. # phase 3: compiler and ninja check
  655. _check_ninja_availability()
  656. _check_compiler_comatibility()
  657. # phase 4: setup the compile flags
  658. user_includes, system_includes = _setup_includes(
  659. extra_include_paths, with_cuda, with_cudnn
  660. )
  661. common_cflags = _setup_common_cflags(user_includes, system_includes)
  662. cuda_cflags = (
  663. _setup_cuda_cflags(common_cflags, extra_cuda_cflags) if with_cuda else None
  664. )
  665. ldflags = _setup_ldflags(extra_ldflags, with_cuda, with_cudnn)
  666. if IS_WINDOWS:
  667. cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags
  668. cflags = _nt_quote_args(cflags)
  669. else:
  670. cflags = common_cflags + ["-fPIC", "-std=c++14"] + extra_cflags
  671. ldflags = _add_shared_flag(ldflags)
  672. if sys.platform.startswith("darwin"):
  673. ldflags.append("-undefined dynamic_lookup")
  674. elif IS_WINDOWS:
  675. ldflags += ["/link"]
  676. ldflags = _nt_quote_args(ldflags)
  677. baton = FileBaton(os.path.join(build_dir, "lock"))
  678. if baton.try_acquire():
  679. try:
  680. # phase 5: generate ninja build file
  681. objs = [_obj_file_path(src) for src in sources]
  682. name += ".dll" if IS_WINDOWS else ".so"
  683. build_file_path = os.path.join(build_dir, "build.ninja")
  684. if verbose:
  685. print("Emitting ninja build file {}".format(build_file_path))
  686. _dump_ninja_file(
  687. path=build_file_path,
  688. cflags=cflags,
  689. post_cflags=None,
  690. cuda_cflags=cuda_cflags,
  691. cuda_post_cflags=None,
  692. sources=sources,
  693. objects=objs,
  694. ldflags=ldflags,
  695. library_target=name,
  696. with_cuda=with_cuda,
  697. )
  698. # phase 6: build with ninja
  699. if verbose:
  700. print(
  701. "Compiling and linking your custom op {}".format(
  702. os.path.join(build_dir, name)
  703. )
  704. )
  705. _build_with_ninja(build_dir, verbose, "compiling error")
  706. finally:
  707. baton.release()
  708. else:
  709. baton.wait()
  710. return os.path.join(build_dir, name)
  711. def build_and_load(
  712. name: str,
  713. sources: Union[str, List[str]],
  714. extra_cflags: Union[str, List[str]] = [],
  715. extra_cuda_cflags: Union[str, List[str]] = [],
  716. extra_ldflags: Union[str, List[str]] = [],
  717. extra_include_paths: Union[str, List[str]] = [],
  718. with_cuda: Optional[bool] = None,
  719. build_dir: Optional[bool] = None,
  720. verbose: bool = False,
  721. abi_tag: Optional[int] = None,
  722. ) -> str:
  723. r"""Build and Load a Custom Op with ninja in the way of just-in-time (JIT).
  724. Same as the function ``build()`` but load the built dynamic library.
  725. Args:
  726. same as ``build()``
  727. Returns:
  728. the compiled dynamic library path
  729. """
  730. lib_path = build(
  731. name,
  732. sources,
  733. extra_cflags,
  734. extra_cuda_cflags,
  735. extra_ldflags,
  736. extra_include_paths,
  737. with_cuda,
  738. build_dir,
  739. verbose,
  740. abi_tag,
  741. )
  742. if verbose:
  743. print("Load the compiled custom op {}".format(lib_path))
  744. load(lib_path)
  745. return lib_path