You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

custom_op_tools.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import collections
  9. import ctypes
  10. import glob
  11. import os
  12. import re
  13. import subprocess
  14. import sys
  15. import time
  16. from typing import List, Optional, Union
  17. from ..core.ops.custom import load
  18. from ..logger import get_logger
  19. def _get_win_folder_with_ctypes(csidl_name):
  20. csidl_const = {
  21. "CSIDL_APPDATA": 26,
  22. "CSIDL_COMMON_APPDATA": 35,
  23. "CSIDL_LOCAL_APPDATA": 28,
  24. }[csidl_name]
  25. buf = ctypes.create_unicode_buffer(1024)
  26. ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
  27. # Downgrade to short path name if have highbit chars. See
  28. # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
  29. has_high_char = False
  30. for c in buf:
  31. if ord(c) > 255:
  32. has_high_char = True
  33. break
  34. if has_high_char:
  35. buf2 = ctypes.create_unicode_buffer(1024)
  36. if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
  37. buf = buf2
  38. return buf.value
  39. system = sys.platform
  40. if system == "win32":
  41. _get_win_folder = _get_win_folder_with_ctypes
  42. PLAT_TO_VCVARS = {
  43. "win-amd64": "x86_amd64",
  44. }
  45. logger = get_logger()
  46. # environment varible
  47. ev_custom_op_root_dir = "MGE_CUSTOM_OP_DIR"
  48. ev_cuda_root_dir = "CUDA_ROOT_DIR"
  49. ev_cudnn_root_dir = "CUDNN_ROOT_DIR"
  50. # operating system
  51. IS_WINDOWS = system == "win32"
  52. IS_LINUX = system == "linux"
  53. IS_MACOS = system == "darwin"
  54. MGE_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  55. MGE_INC_PATH = os.path.join(MGE_PATH, "core", "include")
  56. MGE_LIB_PATH = os.path.join(MGE_PATH, "core", "lib")
  57. MGE_ABI_VER = 0
  58. # compile version
  59. MINIMUM_GCC_VERSION = (5, 0, 0)
  60. MINIMUM_CLANG_CL_VERSION = (12, 0, 1)
  61. # compile flags
  62. COMMON_MSVC_FLAGS = [
  63. "/MD",
  64. "/wd4002",
  65. "/wd4819",
  66. "/EHsc",
  67. ]
  68. MSVC_IGNORE_CUDAFE_WARNINGS = [
  69. "field_without_dll_interface",
  70. ]
  71. COMMON_NVCC_FLAGS = []
  72. # Finds the CUDA install path
  73. def _find_cuda_root_dir() -> Optional[str]:
  74. cuda_root_dir = os.environ.get(ev_cuda_root_dir)
  75. if cuda_root_dir is None:
  76. try:
  77. which = "where" if IS_WINDOWS else "which"
  78. with open(os.devnull, "w") as devnull:
  79. nvcc = (
  80. subprocess.check_output([which, "nvcc"], stderr=devnull)
  81. .decode()
  82. .rstrip("\r\n")
  83. )
  84. cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
  85. except Exception:
  86. if IS_WINDOWS:
  87. cuda_root_dir = os.environ.get("CUDA_PATH", None)
  88. if cuda_root_dir == None:
  89. cuda_root_dirs = glob.glob(
  90. "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
  91. )
  92. if len(cuda_root_dirs) == 0:
  93. cuda_root_dir = ""
  94. else:
  95. cuda_root_dir = cuda_root_dirs[0]
  96. else:
  97. cuda_root_dir = "/usr/local/cuda"
  98. if not os.path.exists(cuda_root_dir):
  99. cuda_root_dir = None
  100. return cuda_root_dir
  101. def _find_cudnn_root_dir() -> Optional[str]:
  102. cudnn_root_dir = os.environ.get(ev_cudnn_root_dir)
  103. return cudnn_root_dir
  104. CUDA_ROOT_DIR = _find_cuda_root_dir()
  105. CUDNN_ROOT_DIR = _find_cudnn_root_dir()
  106. #####################################################################
  107. # Phase 1
  108. #####################################################################
  109. def _is_cuda_file(path: str) -> bool:
  110. valid_ext = [".cu", ".cuh"]
  111. return os.path.splitext(path)[1] in valid_ext
  112. # Return full path to the user-specific cache dir for this application.
  113. # Typical user cache directories are:
  114. # Mac OS X: ~/Library/Caches/<AppName>
  115. # Unix: ~/.cache/<AppName> (XDG default)
  116. # Windows: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
  117. def _get_user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
  118. if system == "win32":
  119. appauthor = appname if appauthor is None else appauthor
  120. path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
  121. if appname:
  122. if appauthor is not False:
  123. path = os.path.join(path, appauthor)
  124. else:
  125. path = os.path.join(path, appname)
  126. if opinion:
  127. path = os.path.join(path, "Cache")
  128. elif system == "darwin":
  129. path = os.path.expanduser("~/Library/Caches")
  130. if appname:
  131. path = os.path.join(path, appname)
  132. else:
  133. path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
  134. if appname:
  135. path = os.path.join(path, appname)
  136. if appname and version:
  137. path = os.path.join(path, version)
  138. return path
  139. # Returns the path to the root folder under which custom op will built.
  140. def _get_default_build_root() -> str:
  141. return os.path.realpath(_get_user_cache_dir(appname="mge_custom_op"))
  142. def _get_build_dir(name: str) -> str:
  143. custom_op_root_dir = os.environ.get(ev_custom_op_root_dir)
  144. if custom_op_root_dir is None:
  145. custom_op_root_dir = _get_default_build_root()
  146. build_dir = os.path.join(custom_op_root_dir, name)
  147. return build_dir
  148. #####################################################################
  149. # Phase 2
  150. #####################################################################
  151. def update_hash(seed, value):
  152. # using boost::hash_combine
  153. # https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
  154. return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))
  155. def hash_source_files(hash_value, source_files):
  156. for filename in source_files:
  157. with open(filename) as file:
  158. hash_value = update_hash(hash_value, file.read())
  159. return hash_value
  160. def hash_build_args(hash_value, build_args):
  161. for group in build_args:
  162. for arg in group:
  163. hash_value = update_hash(hash_value, arg)
  164. return hash_value
  165. Entry = collections.namedtuple("Entry", "version, hash")
  166. class Versioner(object):
  167. def __init__(self):
  168. self.entries = {}
  169. def get_version(self, name):
  170. entry = self.entries.get(name)
  171. return None if entry is None else entry.version
  172. def bump_version_if_changed(
  173. self, name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag
  174. ):
  175. hash_value = 0
  176. hash_value = hash_source_files(hash_value, sources)
  177. hash_value = hash_build_args(hash_value, build_args)
  178. hash_value = update_hash(hash_value, build_dir)
  179. hash_value = update_hash(hash_value, with_cuda)
  180. hash_value = update_hash(hash_value, with_cudnn)
  181. hash_value = update_hash(hash_value, abi_tag)
  182. entry = self.entries.get(name)
  183. if entry is None:
  184. self.entries[name] = entry = Entry(0, hash_value)
  185. elif hash_value != entry.hash:
  186. self.entries[name] = entry = Entry(entry.version + 1, hash_value)
  187. return entry.version
  188. custom_op_versioner = Versioner()
  189. def version_check(
  190. name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag,
  191. ):
  192. old_version = custom_op_versioner.get_version(name)
  193. version = custom_op_versioner.bump_version_if_changed(
  194. name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag,
  195. )
  196. return version, old_version
  197. #####################################################################
  198. # Phase 3
  199. #####################################################################
  200. def _check_ninja_availability():
  201. try:
  202. subprocess.check_output("ninja --version".split())
  203. except Exception:
  204. raise RuntimeError(
  205. "Ninja is required to build custom op, please install ninja and update your PATH"
  206. )
  207. def _mge_is_built_from_src():
  208. file_path = os.path.abspath(__file__)
  209. if "site-packages" in file_path:
  210. return False
  211. else:
  212. return True
  213. def _accepted_compilers_for_platform():
  214. if IS_WINDOWS:
  215. return ["clang-cl"]
  216. if IS_MACOS:
  217. return ["clang++", "clang"]
  218. if IS_LINUX:
  219. return ["g++", "gcc", "gnu-c++", "gnu-cc"]
  220. # Verifies that the compiler is the expected one for the current platform.
  221. def _check_compiler_existed_for_platform(compiler: str) -> bool:
  222. # there is no suitable cmd like `which` on windows, so we assume the compiler is always true on windows
  223. if IS_WINDOWS:
  224. try:
  225. version_string = subprocess.check_output(
  226. ["clang-cl", "--version"], stderr=subprocess.STDOUT
  227. ).decode()
  228. return True
  229. except Exception:
  230. return False
  231. # use os.path.realpath to resolve any symlinks, in particular from "c++" to e.g. "g++".
  232. which = subprocess.check_output(["which", compiler], stderr=subprocess.STDOUT)
  233. compiler_path = os.path.realpath(which.decode().strip())
  234. if any(name in compiler_path for name in _accepted_compilers_for_platform()):
  235. return True
  236. version_string = subprocess.check_output(
  237. [compiler, "-v"], stderr=subprocess.STDOUT
  238. ).decode()
  239. if sys.platform.startswith("linux"):
  240. pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
  241. results = re.findall(pattern, version_string)
  242. if len(results) != 1:
  243. return False
  244. compiler_path = os.path.realpath(results[0].strip())
  245. return any(name in compiler_path for name in _accepted_compilers_for_platform())
  246. if sys.platform.startswith("darwin"):
  247. return version_string.startswith("Apple clang")
  248. return False
  249. # Verifies that the given compiler is ABI-compatible with MegEngine.
  250. def _check_compiler_abi_compatibility(compiler: str):
  251. # we think if the megengine is built from source, the user will use the same compiler to compile the custom op
  252. if _mge_is_built_from_src() or os.environ.get("MGE_CHECK_ABI", "1") == "0":
  253. return True
  254. # [TODO] There is no particular minimum version we need for clang, so we"re good here.
  255. if sys.platform.startswith("darwin"):
  256. return True
  257. try:
  258. if sys.platform.startswith("linux"):
  259. minimum_required_version = MINIMUM_GCC_VERSION
  260. versionstr = subprocess.check_output(
  261. [compiler, "-dumpfullversion", "-dumpversion"]
  262. )
  263. version = versionstr.decode().strip().split(".")
  264. else:
  265. minimum_required_version = MINIMUM_CLANG_CL_VERSION
  266. compiler_info = subprocess.check_output(
  267. [compiler, "--version"], stderr=subprocess.STDOUT
  268. )
  269. match = re.search(r"(\d+)\.(\d+)\.(\d+)", compiler_info.decode().strip())
  270. version = (0, 0, 0) if match is None else match.groups()
  271. except Exception:
  272. _, error, _ = sys.exc_info()
  273. logger.warning(
  274. "Error checking compiler version for {}: {}".format(compiler, error)
  275. )
  276. return False
  277. if tuple(map(int, version)) >= minimum_required_version:
  278. return True
  279. return False
  280. def _check_compiler_comatibility():
  281. # we use clang-cl on windows, refer: https://clang.llvm.org/docs/UsersManual.html#clang-cl
  282. compiler = (
  283. os.environ.get("CXX", "clang-cl")
  284. if IS_WINDOWS
  285. else os.environ.get("CXX", "c++")
  286. )
  287. existed = _check_compiler_existed_for_platform(compiler)
  288. if existed == False:
  289. log_str = (
  290. "Cannot find compiler which is compatible with the compiler "
  291. "MegEngine was built with for this platform, which is {mge_compiler} on "
  292. "{platform}. Please use {mge_compiler} to to compile your extension. "
  293. "Alternatively, you may compile MegEngine from source using "
  294. "{user_compiler}, and then you can also use {user_compiler} to compile "
  295. "your extension."
  296. ).format(
  297. user_compiler=compiler,
  298. mge_compiler=_accepted_compilers_for_platform()[0],
  299. platform=sys.platform,
  300. )
  301. logger.warning(log_str)
  302. return False
  303. compatible = _check_compiler_abi_compatibility(compiler)
  304. if compatible == False:
  305. log_str = (
  306. "Your compiler version may be ABI-incompatible with MegEngine! "
  307. "Please use a compiler that is ABI-compatible with GCC 5.0 on Linux "
  308. "and LLVM/Clang 12.0 on Windows ."
  309. )
  310. logger.warning(log_str)
  311. return True
  312. #####################################################################
  313. # Phase 4
  314. #####################################################################
  315. # Quote command-line arguments for DOS/Windows conventions.
  316. def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
  317. # Cover None-type
  318. if not args:
  319. return []
  320. return ['"{}"'.format(arg) if " " in arg else arg for arg in args]
  321. # Now we need user to specify the arch of GPU
  322. def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
  323. return []
  324. def _setup_sys_includes(with_cuda: bool, with_cudnn: bool):
  325. includes = [os.path.join(MGE_INC_PATH)]
  326. if with_cuda:
  327. includes.append(os.path.join(CUDA_ROOT_DIR, "include"))
  328. if with_cudnn:
  329. includes.append(os.path.join(CUDNN_ROOT_DIR, "include"))
  330. return includes
  331. def _setup_includes(extra_include_paths: List[str], with_cuda: bool, with_cudnn: bool):
  332. user_includes = [os.path.abspath(path) for path in extra_include_paths]
  333. system_includes = _setup_sys_includes(with_cuda, with_cudnn)
  334. if IS_WINDOWS:
  335. user_includes += system_includes
  336. system_includes.clear()
  337. return user_includes, system_includes
  338. def _setup_common_cflags(user_includes: List[str], system_includes: List[str]):
  339. common_cflags = []
  340. common_cflags += ["-I{}".format(include) for include in user_includes]
  341. common_cflags += ["-isystem {}".format(include) for include in system_includes]
  342. if not IS_WINDOWS:
  343. common_cflags += ["-D_GLIBCXX_USE_CXX11_ABI={}".format(MGE_ABI_VER)]
  344. return common_cflags
  345. def _setup_cuda_cflags(cflags: List[str], extra_cuda_cflags: List[str]):
  346. cuda_flags = cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
  347. if IS_WINDOWS:
  348. for flag in COMMON_MSVC_FLAGS:
  349. cuda_flags = ["-Xcompiler", flag] + cuda_flags
  350. for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
  351. cuda_flags = ["-Xcudafe", "--diag_suppress=" + ignore_warning] + cuda_flags
  352. cuda_flags = _nt_quote_args(cuda_flags)
  353. cuda_flags += _nt_quote_args(extra_cuda_cflags)
  354. else:
  355. cuda_flags += ["--compiler-options", '"-fPIC"']
  356. cuda_flags += extra_cuda_cflags
  357. if not any(flag.startswith("-std=") for flag in cuda_flags):
  358. cuda_flags.append("-std=c++14")
  359. if os.getenv("CC") is not None:
  360. cuda_flags = ["-ccbin", os.getenv("CC")] + cuda_flags
  361. return cuda_flags
  362. def _setup_ldflags(
  363. extra_ldflags: List[str], with_cuda: bool, with_cudnn: bool
  364. ) -> List[str]:
  365. ldflags = extra_ldflags
  366. if IS_WINDOWS:
  367. ldflags.append(os.path.join(MGE_LIB_PATH, "megengine_shared.lib"))
  368. if with_cuda:
  369. ldflags.append(os.path.join(CUDA_ROOT_DIR, "lib", "x64", "cudart.lib"))
  370. if with_cudnn:
  371. ldflags.append(os.path.join(CUDNN_ROOT_DIR, "lib", "x64", "cudnn.lib"))
  372. else:
  373. ldflags.append("-lmegengine_shared -L{}".format(MGE_LIB_PATH))
  374. ldflags.append("-Wl,-rpath,{}".format(MGE_LIB_PATH))
  375. if with_cuda:
  376. ldflags.append("-lcudart")
  377. ldflags.append("-L{}".format(os.path.join(CUDA_ROOT_DIR, "lib64")))
  378. ldflags.append("-Wl,-rpath,{}".format(os.path.join(CUDA_ROOT_DIR, "lib64")))
  379. if with_cudnn:
  380. ldflags.append("-L{}".format(os.path.join(CUDNN_ROOT_DIR, "lib64")))
  381. ldflags.append(
  382. "-Wl,-rpath,{}".format(os.path.join(CUDNN_ROOT_DIR, "lib64"))
  383. )
  384. return ldflags
  385. def _add_shared_flag(ldflags: List[str]):
  386. ldflags += ["/LD" if IS_WINDOWS else "-shared"]
  387. return ldflags
  388. #####################################################################
  389. # Phase 5
  390. #####################################################################
  391. def _obj_file_path(src_file_path: str):
  392. file_name = os.path.splitext(os.path.basename(src_file_path))[0]
  393. if _is_cuda_file(src_file_path):
  394. target = "{}.cuda.o".format(file_name)
  395. else:
  396. target = "{}.o".format(file_name)
  397. return target
  398. def _dump_ninja_file(
  399. path,
  400. cflags,
  401. post_cflags,
  402. cuda_cflags,
  403. cuda_post_cflags,
  404. sources,
  405. objects,
  406. ldflags,
  407. library_target,
  408. with_cuda,
  409. ):
  410. def sanitize_flags(flags):
  411. return [] if flags is None else [flag.strip() for flag in flags]
  412. cflags = sanitize_flags(cflags)
  413. post_cflags = sanitize_flags(post_cflags)
  414. cuda_cflags = sanitize_flags(cuda_cflags)
  415. cuda_post_cflags = sanitize_flags(cuda_post_cflags)
  416. ldflags = sanitize_flags(ldflags)
  417. assert len(sources) == len(objects)
  418. assert len(sources) > 0
  419. if IS_WINDOWS:
  420. compiler = os.environ.get("CXX", "clang-cl")
  421. else:
  422. compiler = os.environ.get("CXX", "c++")
  423. # Version 1.3 is required for the `deps` directive.
  424. config = ["ninja_required_version = 1.3"]
  425. config.append("cxx = {}".format(compiler))
  426. if with_cuda:
  427. nvcc = os.path.join(CUDA_ROOT_DIR, "bin", "nvcc")
  428. config.append("nvcc = {}".format(nvcc))
  429. flags = ["cflags = {}".format(" ".join(cflags))]
  430. flags.append("post_cflags = {}".format(" ".join(post_cflags)))
  431. if with_cuda:
  432. flags.append("cuda_cflags = {}".format(" ".join(cuda_cflags)))
  433. flags.append("cuda_post_cflags = {}".format(" ".join(cuda_post_cflags)))
  434. flags.append("ldflags = {}".format(" ".join(ldflags)))
  435. # Turn into absolute paths so we can emit them into the ninja build
  436. # file wherever it is.
  437. sources = [os.path.abspath(file) for file in sources]
  438. # See https://ninja-build.org/build.ninja.html for reference.
  439. compile_rule = ["rule compile"]
  440. if IS_WINDOWS:
  441. compile_rule.append(
  442. " command = clang-cl /showIncludes $cflags -c $in /Fo$out $post_cflags"
  443. )
  444. compile_rule.append(" deps = msvc")
  445. else:
  446. compile_rule.append(
  447. " command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags"
  448. )
  449. compile_rule.append(" depfile = $out.d")
  450. compile_rule.append(" deps = gcc")
  451. if with_cuda:
  452. cuda_compile_rule = ["rule cuda_compile"]
  453. nvcc_gendeps = ""
  454. cuda_compile_rule.append(
  455. " command = $nvcc {} $cuda_cflags -c $in -o $out $cuda_post_cflags".format(
  456. nvcc_gendeps
  457. )
  458. )
  459. # Emit one build rule per source to enable incremental build.
  460. build = []
  461. for source_file, object_file in zip(sources, objects):
  462. is_cuda_source = _is_cuda_file(source_file) and with_cuda
  463. rule = "cuda_compile" if is_cuda_source else "compile"
  464. if IS_WINDOWS:
  465. source_file = source_file.replace(":", "$:")
  466. object_file = object_file.replace(":", "$:")
  467. source_file = source_file.replace(" ", "$ ")
  468. object_file = object_file.replace(" ", "$ ")
  469. build.append("build {}: {} {}".format(object_file, rule, source_file))
  470. if library_target is not None:
  471. link_rule = ["rule link"]
  472. if IS_WINDOWS:
  473. link_rule.append(" command = clang-cl $in /nologo $ldflags /out:$out")
  474. else:
  475. link_rule.append(" command = $cxx $in $ldflags -o $out")
  476. link = ["build {}: link {}".format(library_target, " ".join(objects))]
  477. default = ["default {}".format(library_target)]
  478. else:
  479. link_rule, link, default = [], [], []
  480. # 'Blocks' should be separated by newlines, for visual benefit.
  481. blocks = [config, flags, compile_rule]
  482. if with_cuda:
  483. blocks.append(cuda_compile_rule)
  484. blocks += [link_rule, build, link, default]
  485. with open(path, "w") as build_file:
  486. for block in blocks:
  487. lines = "\n".join(block)
  488. build_file.write("{}\n\n".format(lines))
  489. class FileBaton:
  490. def __init__(self, lock_file_path, wait_seconds=0.1):
  491. self.lock_file_path = lock_file_path
  492. self.wait_seconds = wait_seconds
  493. self.fd = None
  494. def try_acquire(self):
  495. try:
  496. self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
  497. return True
  498. except FileExistsError:
  499. return False
  500. def wait(self):
  501. while os.path.exists(self.lock_file_path):
  502. time.sleep(self.wait_seconds)
  503. def release(self):
  504. if self.fd is not None:
  505. os.close(self.fd)
  506. os.remove(self.lock_file_path)
  507. #####################################################################
  508. # Phase 6
  509. #####################################################################
  510. def _build_with_ninja(build_dir: str, verbose: bool, error_prefix: str):
  511. command = ["ninja", "-v"]
  512. env = os.environ.copy()
  513. try:
  514. sys.stdout.flush()
  515. sys.stderr.flush()
  516. stdout_fileno = 1
  517. subprocess.run(
  518. command,
  519. stdout=stdout_fileno if verbose else subprocess.PIPE,
  520. stderr=subprocess.STDOUT,
  521. cwd=build_dir,
  522. check=True,
  523. env=env,
  524. )
  525. except subprocess.CalledProcessError as e:
  526. with open(os.path.join(build_dir, "build.ninja")) as f:
  527. lines = f.readlines()
  528. print(lines)
  529. _, error, _ = sys.exc_info()
  530. message = error_prefix
  531. if hasattr(error, "output") and error.output:
  532. message += ": {}".format(error.output.decode())
  533. raise RuntimeError(message) from e
  534. def build(
  535. name: str,
  536. sources: Union[str, List[str]],
  537. extra_cflags: Union[str, List[str]] = [],
  538. extra_cuda_cflags: Union[str, List[str]] = [],
  539. extra_ldflags: Union[str, List[str]] = [],
  540. extra_include_paths: Union[str, List[str]] = [],
  541. with_cuda: Optional[bool] = None,
  542. build_dir: Optional[bool] = None,
  543. verbose: bool = False,
  544. abi_tag: Optional[int] = None,
  545. ) -> str:
  546. r"""Build a Custom Op with ninja in the way of just-in-time (JIT).
  547. To build the custom op, a Ninja build file is emitted, which is used to
  548. compile the given sources into a dynamic library.
  549. By default, the directory to which the build file is emitted and the
  550. resulting library compiled to is ``<tmp>/mge_custom_op/<name>``, where
  551. ``<tmp>`` is the temporary folder on the current platform and ``<name>``
  552. the name of the custom op. This location can be overridden in two ways.
  553. First, if the ``MGE_CUSTOM_OP_DIR`` environment variable is set, it
  554. replaces ``<tmp>/mge_custom_op`` and all custom op will be compiled
  555. into subfolders of this directory. Second, if the ``build_dir``
  556. argument to this function is supplied, it overrides the entire path, i.e.
  557. the library will be compiled into that folder directly.
  558. To compile the sources, the default system compiler (``c++``) is used,
  559. which can be overridden by setting the ``CXX`` environment variable. To pass
  560. additional arguments to the compilation process, ``extra_cflags`` or
  561. ``extra_ldflags`` can be provided. For example, to compile your custom op
  562. with optimizations, pass ``extra_cflags=['-O3']``. You can also use
  563. ``extra_cflags`` to pass further include directories.
  564. CUDA support with mixed compilation is provided. Simply pass CUDA source
  565. files (``.cu`` or ``.cuh``) along with other sources. Such files will be
  566. detected and compiled with nvcc rather than the C++ compiler. This includes
  567. passing the CUDA lib64 directory as a library directory, and linking
  568. ``cudart``. You can pass additional flags to nvcc via
  569. ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
  570. heuristics for finding the CUDA install directory are used, which usually
  571. work fine. If not, setting the ``CUDA_ROOT_DIR`` environment variable is the
  572. safest option. If you use CUDNN, please also setting the ``CUDNN_ROOT_DIR``
  573. environment variable.
  574. Args:
  575. name: The name of the custom op to build.
  576. sources: A list of relative or absolute paths to C++ source files.
  577. extra_cflags: optional list of compiler flags to forward to the build.
  578. extra_cuda_cflags: optional list of compiler flags to forward to nvcc
  579. when building CUDA sources.
  580. extra_ldflags: optional list of linker flags to forward to the build.
  581. extra_include_paths: optional list of include directories to forward
  582. to the build.
  583. with_cuda: Determines whether CUDA headers and libraries are added to
  584. the build. If set to ``None`` (default), this value is
  585. automatically determined based on the existence of ``.cu`` or
  586. ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers
  587. and libraries to be included.
  588. build_dir: optional path to use as build workspace.
  589. verbose: If ``True``, turns on verbose logging of load steps.
  590. abi_tag: Determines the value of MACRO ``_GLIBCXX_USE_CXX11_ABI``
  591. in gcc compiler, should be ``0`` or ``1``.
  592. Returns:
  593. the compiled dynamic library path
  594. """
  595. # phase 1: prepare config
  596. if abi_tag != None:
  597. global MGE_ABI_VER
  598. MGE_ABI_VER = abi_tag
  599. def strlist(args, name):
  600. assert isinstance(args, str) or isinstance(
  601. args, list
  602. ), "{} must be str or list[str]".format(name)
  603. if isinstance(args, str):
  604. return [args]
  605. for arg in args:
  606. assert isinstance(arg, str)
  607. args = [arg.strip() for arg in args]
  608. return args
  609. sources = strlist(sources, "sources")
  610. extra_cflags = strlist(extra_cflags, "extra_cflags")
  611. extra_cuda_cflags = strlist(extra_cuda_cflags, "extra_cuda_cflags")
  612. extra_ldflags = strlist(extra_ldflags, "extra_ldflags")
  613. extra_include_paths = strlist(extra_include_paths, "extra_include_paths")
  614. with_cuda = any(map(_is_cuda_file, sources)) if with_cuda is None else with_cuda
  615. with_cudnn = any(["cudnn" in f for f in extra_ldflags])
  616. if CUDA_ROOT_DIR == None and with_cuda:
  617. print(
  618. "No CUDA runtime is found, using {}=/path/to/your/cuda_root_dir".format(
  619. ev_cuda_root_dir
  620. )
  621. )
  622. if CUDNN_ROOT_DIR == None and with_cudnn:
  623. print(
  624. "Cannot find the root directory of cudnn, using {}=/path/to/your/cudnn_root_dir".format(
  625. ev_cudnn_root_dir
  626. )
  627. )
  628. build_dir = os.path.abspath(
  629. _get_build_dir(name) if build_dir is None else build_dir
  630. )
  631. if not os.path.exists(build_dir):
  632. os.makedirs(build_dir, exist_ok=True)
  633. if verbose:
  634. print("Using {} to build megengine custom op".format(build_dir))
  635. # phase 2: version check
  636. version, old_version = version_check(
  637. name,
  638. sources,
  639. [extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
  640. build_dir,
  641. with_cuda,
  642. with_cudnn,
  643. abi_tag,
  644. )
  645. if verbose:
  646. if version != old_version and old_version != None:
  647. print(
  648. "Input conditions of custom op {} have changed, bumping to version {}".format(
  649. name, version
  650. )
  651. )
  652. print("Building custom op {} with version {}".format(name, version))
  653. if version == old_version:
  654. if verbose:
  655. print(
  656. "No modifications detected for {}, skipping build step...".format(name)
  657. )
  658. return
  659. name = "{}_v{}".format(name, version)
  660. # phase 3: compiler and ninja check
  661. _check_ninja_availability()
  662. _check_compiler_comatibility()
  663. # phase 4: setup the compile flags
  664. user_includes, system_includes = _setup_includes(
  665. extra_include_paths, with_cuda, with_cudnn
  666. )
  667. common_cflags = _setup_common_cflags(user_includes, system_includes)
  668. cuda_cflags = (
  669. _setup_cuda_cflags(common_cflags, extra_cuda_cflags) if with_cuda else None
  670. )
  671. ldflags = _setup_ldflags(extra_ldflags, with_cuda, with_cudnn)
  672. if IS_WINDOWS:
  673. cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags
  674. cflags = _nt_quote_args(cflags)
  675. else:
  676. cflags = common_cflags + ["-fPIC", "-std=c++14"] + extra_cflags
  677. ldflags = _add_shared_flag(ldflags)
  678. if sys.platform.startswith("darwin"):
  679. ldflags.append("-undefined dynamic_lookup")
  680. elif IS_WINDOWS:
  681. ldflags += ["/link"]
  682. ldflags = _nt_quote_args(ldflags)
  683. baton = FileBaton(os.path.join(build_dir, "lock"))
  684. if baton.try_acquire():
  685. try:
  686. # phase 5: generate ninja build file
  687. objs = [_obj_file_path(src) for src in sources]
  688. name += ".dll" if IS_WINDOWS else ".so"
  689. build_file_path = os.path.join(build_dir, "build.ninja")
  690. if verbose:
  691. print("Emitting ninja build file {}".format(build_file_path))
  692. _dump_ninja_file(
  693. path=build_file_path,
  694. cflags=cflags,
  695. post_cflags=None,
  696. cuda_cflags=cuda_cflags,
  697. cuda_post_cflags=None,
  698. sources=sources,
  699. objects=objs,
  700. ldflags=ldflags,
  701. library_target=name,
  702. with_cuda=with_cuda,
  703. )
  704. # phase 6: build with ninja
  705. if verbose:
  706. print(
  707. "Compiling and linking your custom op {}".format(
  708. os.path.join(build_dir, name)
  709. )
  710. )
  711. _build_with_ninja(build_dir, verbose, "compiling error")
  712. finally:
  713. baton.release()
  714. else:
  715. baton.wait()
  716. return os.path.join(build_dir, name)
  717. def build_and_load(
  718. name: str,
  719. sources: Union[str, List[str]],
  720. extra_cflags: Union[str, List[str]] = [],
  721. extra_cuda_cflags: Union[str, List[str]] = [],
  722. extra_ldflags: Union[str, List[str]] = [],
  723. extra_include_paths: Union[str, List[str]] = [],
  724. with_cuda: Optional[bool] = None,
  725. build_dir: Optional[bool] = None,
  726. verbose: bool = False,
  727. abi_tag: Optional[int] = None,
  728. ) -> str:
  729. r"""Build and Load a Custom Op with ninja in the way of just-in-time (JIT).
  730. Same as the function ``build()`` but load the built dynamic library.
  731. Args:
  732. same as ``build()``
  733. Returns:
  734. the compiled dynamic library path
  735. """
  736. lib_path = build(
  737. name,
  738. sources,
  739. extra_cflags,
  740. extra_cuda_cflags,
  741. extra_ldflags,
  742. extra_include_paths,
  743. with_cuda,
  744. build_dir,
  745. verbose,
  746. abi_tag,
  747. )
  748. if verbose:
  749. print("Load the compiled custom op {}".format(lib_path))
  750. load(lib_path)
  751. return lib_path