diff --git a/CMakeLists.txt b/CMakeLists.txt index 2dbd6d28..b7efcbf3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1145,6 +1145,14 @@ if(TARGET _imperative_rt) COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/version.py ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/version.py + COMMAND ${CMAKE_COMMAND} -E create_symlink + ${CMAKE_CURRENT_SOURCE_DIR}/src/custom/include + ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/include + COMMAND ${CMAKE_COMMAND} -E make_directory + ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/lib + COMMAND ${CMAKE_COMMAND} -E create_symlink + ${CMAKE_CURRENT_BINARY_DIR}/src/$ + ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/lib/$ DEPENDS _imperative_rt VERBATIM ) diff --git a/imperative/CMakeLists.txt b/imperative/CMakeLists.txt index cf1a8c35..bd6359c4 100644 --- a/imperative/CMakeLists.txt +++ b/imperative/CMakeLists.txt @@ -67,8 +67,11 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/LICENSE ${PROJECT_SOURCE_DIR}/ACKNOWLEDGMENTS ${PROJECT_BINARY_DIR} COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/$ # clean develop COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/version.py # clean develop + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/include # clean develop + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/lib # clean develop COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine ${CMAKE_CURRENT_BINARY_DIR}/python/megengine COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${CMAKE_CURRENT_BINARY_DIR}/python/test + COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/custom/include ${CMAKE_CURRENT_BINARY_DIR}/python/megengine/core/include COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/setup.py ${CMAKE_CURRENT_BINARY_DIR}/python/setup.py COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires.txt COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-style.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-style.txt diff --git a/imperative/python/megengine/core/ops/custom.py b/imperative/python/megengine/core/ops/custom.py index b1a055fd..b60527c3 100644 --- a/imperative/python/megengine/core/ops/custom.py +++ b/imperative/python/megengine/core/ops/custom.py @@ -7,11 +7,14 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+import os + from .._imperative_rt.ops._custom import ( _get_custom_op_list, _install, _make_custom_op, _uninstall, + get_custom_op_abi_tag, ) __all__ = ["load"] @@ -25,8 +28,16 @@ def _gen_custom_op_maker(custom_op_name): def load(lib_path): - op_in_this_lib = _install(lib_path[0:-3], lib_path) + lib_path = os.path.abspath(lib_path) + lib_name = os.path.splitext(lib_path)[0] + op_in_this_lib = _install(lib_name, lib_path) for op in op_in_this_lib: op_maker = _gen_custom_op_maker(op) globals()[op] = op_maker __all__.append(op) + + +def unload(lib_path): + lib_path = os.path.abspath(lib_path) + lib_name = os.path.splitext(lib_path)[0] + _uninstall(lib_name) diff --git a/imperative/python/megengine/utils/custom_op_tools.py b/imperative/python/megengine/utils/custom_op_tools.py new file mode 100644 index 00000000..d9150fef --- /dev/null +++ b/imperative/python/megengine/utils/custom_op_tools.py @@ -0,0 +1,909 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import collections +import ctypes +import glob +import os +import re +import subprocess +import sys +import time +from typing import List, Optional, Union + +from ..core.ops.custom import load +from ..logger import get_logger + + +def _get_win_folder_with_ctypes(csidl_name): + csidl_const = { + "CSIDL_APPDATA": 26, + "CSIDL_COMMON_APPDATA": 35, + "CSIDL_LOCAL_APPDATA": 28, + }[csidl_name] + + buf = ctypes.create_unicode_buffer(1024) + ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) + + # Downgrade to short path name if have highbit chars. See + # . 
+    has_high_char = False
+    for c in buf:
+        if ord(c) > 255:
+            has_high_char = True
+            break
+    if has_high_char:
+        buf2 = ctypes.create_unicode_buffer(1024)
+        if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
+            buf = buf2
+
+    return buf.value
+
+
+system = sys.platform
+if system == "win32":
+    _get_win_folder = _get_win_folder_with_ctypes
+
+PLAT_TO_VCVARS = {
+    "win-amd64": "x86_amd64",
+}
+
+logger = get_logger()
+
+# environment variables
+ev_custom_op_root_dir = "MGE_CUSTOM_OP_DIR"
+ev_cuda_root_dir = "CUDA_ROOT_DIR"
+ev_cudnn_root_dir = "CUDNN_ROOT_DIR"
+
+# operating system
+IS_WINDOWS = system == "win32"
+IS_LINUX = system == "linux"
+IS_MACOS = system == "darwin"
+
+MGE_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+MGE_INC_PATH = os.path.join(MGE_PATH, "core", "include")
+MGE_LIB_PATH = os.path.join(MGE_PATH, "core", "lib")
+MGE_ABI_VER = 0
+
+
+# compile version
+MINIMUM_GCC_VERSION = (5, 0, 0)
+MINIMUM_CLANG_CL_VERSION = (12, 0, 1)
+
+# compile flags
+COMMON_MSVC_FLAGS = [
+    "/MD",
+    "/wd4002",
+    "/wd4819",
+    "/EHsc",
+]
+
+MSVC_IGNORE_CUDAFE_WARNINGS = [
+    "field_without_dll_interface",
+]
+
+COMMON_NVCC_FLAGS = []
+
+# Finds the CUDA install path
+def _find_cuda_root_dir() -> Optional[str]:
+    cuda_root_dir = os.environ.get(ev_cuda_root_dir)
+    if cuda_root_dir is None:
+        try:
+            which = "where" if IS_WINDOWS else "which"
+            with open(os.devnull, "w") as devnull:
+                nvcc = (
+                    subprocess.check_output([which, "nvcc"], stderr=devnull)
+                    .decode()
+                    .rstrip("\r\n")
+                )
+                cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
+        except Exception:
+            if IS_WINDOWS:
+                cuda_root_dir = os.environ.get("CUDA_PATH", None)
+                if cuda_root_dir is None:
+                    cuda_root_dirs = glob.glob(
+                        "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
+                    )
+                    if len(cuda_root_dirs) == 0:
+                        cuda_root_dir = ""
+                    else:
+                        cuda_root_dir = cuda_root_dirs[0]
+            else:
+                cuda_root_dir = "/usr/local/cuda"
+            if not os.path.exists(cuda_root_dir):
+                cuda_root_dir = None
+    return cuda_root_dir
+
+
+def _find_cudnn_root_dir() -> Optional[str]:
+    cudnn_root_dir = os.environ.get(ev_cudnn_root_dir)
+    return cudnn_root_dir
+
+
+CUDA_ROOT_DIR = _find_cuda_root_dir()
+CUDNN_ROOT_DIR = _find_cudnn_root_dir()
+
+#####################################################################
+# Phase 1
+#####################################################################
+
+
+def _is_cuda_file(path: str) -> bool:
+    valid_ext = [".cu", ".cuh"]
+    return os.path.splitext(path)[1] in valid_ext
+
+
+# Return full path to the user-specific cache dir for this application.
+# Typical user cache directories are:
+#     Mac OS X:   ~/Library/Caches/<AppName>
+#     Unix:       ~/.cache/<AppName> (XDG default)
+#     Windows:    C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
+def _get_user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
+    if system == "win32":
+        appauthor = appname if appauthor is None else appauthor
+        path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
+        if appname:
+            if appauthor is not False:
+                path = os.path.join(path, appauthor)
+            else:
+                path = os.path.join(path, appname)
+            if opinion:
+                path = os.path.join(path, "Cache")
+    elif system == "darwin":
+        path = os.path.expanduser("~/Library/Caches")
+        if appname:
+            path = os.path.join(path, appname)
+    else:
+        path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+        if appname:
+            path = os.path.join(path, appname)
+    if appname and version:
+        path = os.path.join(path, version)
+    return path
+
+
+# Returns the path to the root folder under which custom ops will be built.
+def _get_default_build_root() -> str:
+    return os.path.realpath(_get_user_cache_dir(appname="mge_custom_op"))
+
+
+def _get_build_dir(name: str) -> str:
+    custom_op_root_dir = os.environ.get(ev_custom_op_root_dir)
+    if custom_op_root_dir is None:
+        custom_op_root_dir = _get_default_build_root()
+
+    build_dir = os.path.join(custom_op_root_dir, name)
+    return build_dir
+
+
+#####################################################################
+# Phase 2
+#####################################################################
+
+
+def update_hash(seed, value):
+    # using boost::hash_combine
+    # https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
+    return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))
+
+
+def hash_source_files(hash_value, source_files):
+    for filename in source_files:
+        with open(filename) as file:
+            hash_value = update_hash(hash_value, file.read())
+    return hash_value
+
+
+def hash_build_args(hash_value, build_args):
+    for group in build_args:
+        for arg in group:
+            hash_value = update_hash(hash_value, arg)
+    return hash_value
+
+
+Entry = collections.namedtuple("Entry", "version, hash")
+
+
+class Versioner(object):
+    def __init__(self):
+        self.entries = {}
+
+    def get_version(self, name):
+        entry = self.entries.get(name)
+        return None if entry is None else entry.version
+
+    def bump_version_if_changed(
+        self, name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag
+    ):
+        hash_value = 0
+        hash_value = hash_source_files(hash_value, sources)
+        hash_value = hash_build_args(hash_value, build_args)
+        hash_value = update_hash(hash_value, build_dir)
+        hash_value = update_hash(hash_value, with_cuda)
+        hash_value = update_hash(hash_value, with_cudnn)
+        hash_value = update_hash(hash_value, abi_tag)
+
+        entry = self.entries.get(name)
+        if entry is None:
+            self.entries[name] = entry = Entry(0, hash_value)
+        elif hash_value != entry.hash:
+            self.entries[name] = entry = Entry(entry.version + 1, hash_value)
+
+        return entry.version
+
+
+custom_op_versioner = Versioner()
+
+
+def version_check(
+    name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag,
+):
+    old_version = custom_op_versioner.get_version(name)
+    version = custom_op_versioner.bump_version_if_changed(
+        name, sources, build_args, build_dir, with_cuda, with_cudnn, abi_tag,
+    )
+    return version, old_version
+
+
+#####################################################################
+# Phase 3
+#####################################################################
+
+
+def _check_ninja_availability():
+    try:
+        subprocess.check_output("ninja --version".split())
+    except Exception:
+        raise RuntimeError(
+            "Ninja is required to build custom op, please install ninja and update your PATH"
+        )
+
+
+def _mge_is_built_from_src():
+    file_path = os.path.abspath(__file__)
+    if "site-packages" in file_path:
+        return False
+    else:
+        return True
+
+
+def _accepted_compilers_for_platform():
+    if IS_WINDOWS:
+        return ["clang-cl"]
+    if IS_MACOS:
+        return ["clang++", "clang"]
+    if IS_LINUX:
+        return ["g++", "gcc", "gnu-c++", "gnu-cc"]
+
+
+# Verifies that the compiler is the expected one for the current platform.
+def _check_compiler_existed_for_platform(compiler: str) -> bool:
+    # there is no suitable cmd like `which` on Windows, so we just check
+    # whether clang-cl can be invoked
+    if IS_WINDOWS:
+        try:
+            version_string = subprocess.check_output(
+                ["clang-cl", "--version"], stderr=subprocess.STDOUT
+            ).decode()
+            return True
+        except Exception:
+            return False
+
+    # use os.path.realpath to resolve any symlinks, in particular from "c++" to e.g. "g++".
+    which = subprocess.check_output(["which", compiler], stderr=subprocess.STDOUT)
+    compiler_path = os.path.realpath(which.decode().strip())
+    if any(name in compiler_path for name in _accepted_compilers_for_platform()):
+        return True
+
+    version_string = subprocess.check_output(
+        [compiler, "-v"], stderr=subprocess.STDOUT
+    ).decode()
+    if sys.platform.startswith("linux"):
+        pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
+        results = re.findall(pattern, version_string)
+        if len(results) != 1:
+            return False
+        compiler_path = os.path.realpath(results[0].strip())
+        return any(name in compiler_path for name in _accepted_compilers_for_platform())
+
+    if sys.platform.startswith("darwin"):
+        return version_string.startswith("Apple clang")
+
+    return False
+
+
+# Verifies that the given compiler is ABI-compatible with MegEngine.
+def _check_compiler_abi_compatibility(compiler: str):
+    # we assume that if MegEngine was built from source, the user will compile
+    # the custom op with the same compiler
+    if _mge_is_built_from_src() or os.environ.get("MGE_CHECK_ABI", "1") == "0":
+        return True
+
+    # [TODO] There is no particular minimum version we need for clang, so we're good here.
+    if sys.platform.startswith("darwin"):
+        return True
+
+    try:
+        if sys.platform.startswith("linux"):
+            minimum_required_version = MINIMUM_GCC_VERSION
+            versionstr = subprocess.check_output(
+                [compiler, "-dumpfullversion", "-dumpversion"]
+            )
+            version = versionstr.decode().strip().split(".")
+        else:
+            minimum_required_version = MINIMUM_CLANG_CL_VERSION
+            compiler_info = subprocess.check_output(
+                [compiler, "--version"], stderr=subprocess.STDOUT
+            )
+            match = re.search(r"(\d+)\.(\d+)\.(\d+)", compiler_info.decode().strip())
+            version = (0, 0, 0) if match is None else match.groups()
+    except Exception:
+        _, error, _ = sys.exc_info()
+        logger.warning(
+            "Error checking compiler version for {}: {}".format(compiler, error)
+        )
+        return False
+
+    if tuple(map(int, version)) >= minimum_required_version:
+        return True
+
+    return False
+
+
+def _check_compiler_compatibility():
+    # we use clang-cl on windows, refer: https://clang.llvm.org/docs/UsersManual.html#clang-cl
+    compiler = (
+        os.environ.get("CXX", "clang-cl")
+        if IS_WINDOWS
+        else os.environ.get("CXX", "c++")
+    )
+
+    existed = _check_compiler_existed_for_platform(compiler)
+    if not existed:
+        log_str = (
+            "Cannot find compiler which is compatible with the compiler "
+            "MegEngine was built with for this platform, which is {mge_compiler} on "
+            "{platform}. Please use {mge_compiler} to compile your extension. "
+            "Alternatively, you may compile MegEngine from source using "
+            "{user_compiler}, and then you can also use {user_compiler} to compile "
+            "your extension."
+        ).format(
+            user_compiler=compiler,
+            mge_compiler=_accepted_compilers_for_platform()[0],
+            platform=sys.platform,
+        )
+
+        logger.warning(log_str)
+        return False
+
+    compatible = _check_compiler_abi_compatibility(compiler)
+    if not compatible:
+        log_str = (
+            "Your compiler version may be ABI-incompatible with MegEngine! "
+            "Please use a compiler that is ABI-compatible with GCC 5.0 on Linux "
+            "and LLVM/Clang 12.0 on Windows."
+        )
+        logger.warning(log_str)
+    return True
+
+
+#####################################################################
+# Phase 4
+#####################################################################
+
+
+# Quote command-line arguments for DOS/Windows conventions.
+def _nt_quote_args(args: Optional[List[str]]) -> List[str]: + # Cover None-type + if not args: + return [] + return ['"{}"'.format(arg) if " " in arg else arg for arg in args] + + +# Now we need user to specify the arch of GPU +def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: + return [] + + +def _setup_sys_includes(with_cuda: bool, with_cudnn: bool): + includes = [os.path.join(MGE_INC_PATH)] + if with_cuda: + includes.append(os.path.join(CUDA_ROOT_DIR, "include")) + if with_cudnn: + includes.append(os.path.join(CUDNN_ROOT_DIR, "include")) + return includes + + +def _setup_includes(extra_include_paths: List[str], with_cuda: bool, with_cudnn: bool): + user_includes = [os.path.abspath(path) for path in extra_include_paths] + system_includes = _setup_sys_includes(with_cuda, with_cudnn) + if IS_WINDOWS: + user_includes += system_includes + system_includes.clear() + return user_includes, system_includes + + +def _setup_common_cflags(user_includes: List[str], system_includes: List[str]): + common_cflags = [] + common_cflags += ["-I{}".format(include) for include in user_includes] + common_cflags += ["-isystem {}".format(include) for include in system_includes] + if not IS_WINDOWS: + common_cflags += ["-D_GLIBCXX_USE_CXX11_ABI={}".format(MGE_ABI_VER)] + return common_cflags + + +def _setup_cuda_cflags(cflags: List[str], extra_cuda_cflags: List[str]): + cuda_flags = cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() + if IS_WINDOWS: + for flag in COMMON_MSVC_FLAGS: + cuda_flags = ["-Xcompiler", flag] + cuda_flags + for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: + cuda_flags = ["-Xcudafe", "--diag_suppress=" + ignore_warning] + cuda_flags + cuda_flags = _nt_quote_args(cuda_flags) + cuda_flags += _nt_quote_args(extra_cuda_cflags) + else: + cuda_flags += ["--compiler-options", '"-fPIC"'] + cuda_flags += extra_cuda_cflags + if not any(flag.startswith("-std=") for flag in cuda_flags): + cuda_flags.append("-std=c++14") + if os.getenv("CC") is not None: + cuda_flags = ["-ccbin", os.getenv("CC")] + cuda_flags + return cuda_flags + + +def _setup_ldflags( + extra_ldflags: List[str], with_cuda: bool, with_cudnn: bool +) -> List[str]: + ldflags = extra_ldflags + if IS_WINDOWS: + ldflags.append(os.path.join(MGE_LIB_PATH, "megengine_shared.lib")) + if with_cuda: + ldflags.append(os.path.join(CUDA_ROOT_DIR, "lib", "x64", "cudart.lib")) + if with_cudnn: + ldflags.append(os.path.join(CUDNN_ROOT_DIR, "lib", "x64", "cudnn.lib")) + + else: + ldflags.append("-lmegengine_shared -L{}".format(MGE_LIB_PATH)) + ldflags.append("-Wl,-rpath,{}".format(MGE_LIB_PATH)) + if with_cuda: + ldflags.append("-lcudart") + ldflags.append("-L{}".format(os.path.join(CUDA_ROOT_DIR, "lib64"))) + ldflags.append("-Wl,-rpath,{}".format(os.path.join(CUDA_ROOT_DIR, "lib64"))) + if with_cudnn: + ldflags.append("-L{}".format(os.path.join(CUDNN_ROOT_DIR, "lib64"))) + ldflags.append( + "-Wl,-rpath,{}".format(os.path.join(CUDNN_ROOT_DIR, "lib64")) + ) + + return ldflags + + +def _add_shared_flag(ldflags: List[str]): + ldflags += ["/LD" if IS_WINDOWS else "-shared"] + return ldflags + + +##################################################################### +# Phase 5 +##################################################################### + + +def _obj_file_path(src_file_path: str): + file_name = os.path.splitext(os.path.basename(src_file_path))[0] + if _is_cuda_file(src_file_path): + target = "{}.cuda.o".format(file_name) + else: + target = "{}.o".format(file_name) + return target + + +def _dump_ninja_file( + path, + 
cflags, + post_cflags, + cuda_cflags, + cuda_post_cflags, + sources, + objects, + ldflags, + library_target, + with_cuda, +): + def sanitize_flags(flags): + return [] if flags is None else [flag.strip() for flag in flags] + + cflags = sanitize_flags(cflags) + post_cflags = sanitize_flags(post_cflags) + cuda_cflags = sanitize_flags(cuda_cflags) + cuda_post_cflags = sanitize_flags(cuda_post_cflags) + ldflags = sanitize_flags(ldflags) + + assert len(sources) == len(objects) + assert len(sources) > 0 + + if IS_WINDOWS: + compiler = os.environ.get("CXX", "clang-cl") + else: + compiler = os.environ.get("CXX", "c++") + + # Version 1.3 is required for the `deps` directive. + config = ["ninja_required_version = 1.3"] + config.append("cxx = {}".format(compiler)) + if with_cuda: + nvcc = os.path.join(CUDA_ROOT_DIR, "bin", "nvcc") + config.append("nvcc = {}".format(nvcc)) + + flags = ["cflags = {}".format(" ".join(cflags))] + flags.append("post_cflags = {}".format(" ".join(post_cflags))) + if with_cuda: + flags.append("cuda_cflags = {}".format(" ".join(cuda_cflags))) + flags.append("cuda_post_cflags = {}".format(" ".join(cuda_post_cflags))) + flags.append("ldflags = {}".format(" ".join(ldflags))) + + # Turn into absolute paths so we can emit them into the ninja build + # file wherever it is. + sources = [os.path.abspath(file) for file in sources] + + # See https://ninja-build.org/build.ninja.html for reference. + compile_rule = ["rule compile"] + if IS_WINDOWS: + compile_rule.append( + " command = clang-cl /showIncludes $cflags -c $in /Fo$out $post_cflags" + ) + compile_rule.append(" deps = msvc") + else: + compile_rule.append( + " command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags" + ) + compile_rule.append(" depfile = $out.d") + compile_rule.append(" deps = gcc") + + if with_cuda: + cuda_compile_rule = ["rule cuda_compile"] + nvcc_gendeps = "" + cuda_compile_rule.append( + " command = $nvcc {} $cuda_cflags -c $in -o $out $cuda_post_cflags".format( + nvcc_gendeps + ) + ) + + # Emit one build rule per source to enable incremental build. + build = [] + for source_file, object_file in zip(sources, objects): + is_cuda_source = _is_cuda_file(source_file) and with_cuda + rule = "cuda_compile" if is_cuda_source else "compile" + if IS_WINDOWS: + source_file = source_file.replace(":", "$:") + object_file = object_file.replace(":", "$:") + source_file = source_file.replace(" ", "$ ") + object_file = object_file.replace(" ", "$ ") + build.append("build {}: {} {}".format(object_file, rule, source_file)) + + if library_target is not None: + link_rule = ["rule link"] + if IS_WINDOWS: + link_rule.append(" command = clang-cl $in /nologo $ldflags /out:$out") + else: + link_rule.append(" command = $cxx $in $ldflags -o $out") + + link = ["build {}: link {}".format(library_target, " ".join(objects))] + default = ["default {}".format(library_target)] + else: + link_rule, link, default = [], [], [] + + # 'Blocks' should be separated by newlines, for visual benefit. 
+    blocks = [config, flags, compile_rule]
+    if with_cuda:
+        blocks.append(cuda_compile_rule)
+    blocks += [link_rule, build, link, default]
+    with open(path, "w") as build_file:
+        for block in blocks:
+            lines = "\n".join(block)
+            build_file.write("{}\n\n".format(lines))
+
+
+class FileBaton:
+    def __init__(self, lock_file_path, wait_seconds=0.1):
+        self.lock_file_path = lock_file_path
+        self.wait_seconds = wait_seconds
+        self.fd = None
+
+    def try_acquire(self):
+        try:
+            self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
+            return True
+        except FileExistsError:
+            return False
+
+    def wait(self):
+        while os.path.exists(self.lock_file_path):
+            time.sleep(self.wait_seconds)
+
+    def release(self):
+        if self.fd is not None:
+            os.close(self.fd)
+
+        os.remove(self.lock_file_path)
+
+
+#####################################################################
+# Phase 6
+#####################################################################
+
+
+def _build_with_ninja(build_dir: str, verbose: bool, error_prefix: str):
+    command = ["ninja", "-v"]
+    env = os.environ.copy()
+    try:
+        sys.stdout.flush()
+        sys.stderr.flush()
+        stdout_fileno = 1
+        subprocess.run(
+            command,
+            stdout=stdout_fileno if verbose else subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            cwd=build_dir,
+            check=True,
+            env=env,
+        )
+    except subprocess.CalledProcessError as e:
+        with open(os.path.join(build_dir, "build.ninja")) as f:
+            lines = f.readlines()
+        print(lines)
+        _, error, _ = sys.exc_info()
+        message = error_prefix
+        if hasattr(error, "output") and error.output:
+            message += ": {}".format(error.output.decode())
+        raise RuntimeError(message) from e
+
+
+def build(
+    name: str,
+    sources: Union[str, List[str]],
+    extra_cflags: Union[str, List[str]] = [],
+    extra_cuda_cflags: Union[str, List[str]] = [],
+    extra_ldflags: Union[str, List[str]] = [],
+    extra_include_paths: Union[str, List[str]] = [],
+    with_cuda: Optional[bool] = None,
+    build_dir: Optional[str] = None,
+    verbose: bool = False,
+    abi_tag: Optional[int] = None,
+) -> str:
+    r"""Build a Custom Op with ninja in a just-in-time (JIT) way.
+
+    To build the custom op, a Ninja build file is emitted, which is used to
+    compile the given sources into a dynamic library.
+
+    By default, the directory to which the build file is emitted and the
+    resulting library compiled to is ``<tmp>/mge_custom_op/<name>``, where
+    ``<tmp>`` is the temporary folder on the current platform and ``<name>``
+    the name of the custom op. This location can be overridden in two ways.
+    First, if the ``MGE_CUSTOM_OP_DIR`` environment variable is set, it
+    replaces ``<tmp>/mge_custom_op`` and all custom ops will be compiled
+    into subfolders of this directory. Second, if the ``build_dir``
+    argument to this function is supplied, it overrides the entire path, i.e.
+    the library will be compiled into that folder directly.
+
+    To compile the sources, the default system compiler (``c++``) is used,
+    which can be overridden by setting the ``CXX`` environment variable. To pass
+    additional arguments to the compilation process, ``extra_cflags`` or
+    ``extra_ldflags`` can be provided. For example, to compile your custom op
+    with optimizations, pass ``extra_cflags=['-O3']``. You can also use
+    ``extra_cflags`` to pass further include directories.
+
+    CUDA support with mixed compilation is provided. Simply pass CUDA source
+    files (``.cu`` or ``.cuh``) along with other sources. Such files will be
+    detected and compiled with nvcc rather than the C++ compiler. This includes
+    passing the CUDA lib64 directory as a library directory, and linking
+    ``cudart``. You can pass additional flags to nvcc via
+    ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
+    heuristics for finding the CUDA install directory are used, which usually
+    work fine. If not, setting the ``CUDA_ROOT_DIR`` environment variable is the
+    safest option. If you use cuDNN, please also set the ``CUDNN_ROOT_DIR``
+    environment variable.
+
+    Args:
+        name: The name of the custom op to build.
+        sources: A list of relative or absolute paths to C++ source files.
+        extra_cflags: optional list of compiler flags to forward to the build.
+        extra_cuda_cflags: optional list of compiler flags to forward to nvcc
+            when building CUDA sources.
+        extra_ldflags: optional list of linker flags to forward to the build.
+        extra_include_paths: optional list of include directories to forward
+            to the build.
+        with_cuda: Determines whether CUDA headers and libraries are added to
+            the build. If set to ``None`` (default), this value is
+            automatically determined based on the existence of ``.cu`` or
+            ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
+            and libraries to be included.
+        build_dir: optional path to use as build workspace.
+        verbose: If ``True``, turns on verbose logging of load steps.
+        abi_tag: Determines the value of the macro ``_GLIBCXX_USE_CXX11_ABI``
+            for the gcc compiler; should be ``0`` or ``1``.
+
+    Returns:
+        the compiled dynamic library path
+
+    """
+
+    # phase 1: prepare config
+    if abi_tag is not None:
+        global MGE_ABI_VER
+        MGE_ABI_VER = abi_tag
+
+    def strlist(args, name):
+        assert isinstance(args, str) or isinstance(
+            args, list
+        ), "{} must be str or list[str]".format(name)
+        if isinstance(args, str):
+            return [args]
+        for arg in args:
+            assert isinstance(arg, str)
+        args = [arg.strip() for arg in args]
+        return args
+
+    sources = strlist(sources, "sources")
+    extra_cflags = strlist(extra_cflags, "extra_cflags")
+    extra_cuda_cflags = strlist(extra_cuda_cflags, "extra_cuda_cflags")
+    extra_ldflags = strlist(extra_ldflags, "extra_ldflags")
+    extra_include_paths = strlist(extra_include_paths, "extra_include_paths")
+
+    with_cuda = any(map(_is_cuda_file, sources)) if with_cuda is None else with_cuda
+    with_cudnn = any(["cudnn" in f for f in extra_ldflags])
+
+    if CUDA_ROOT_DIR is None and with_cuda:
+        print(
+            "No CUDA runtime is found, please set {}=/path/to/your/cuda_root_dir".format(
+                ev_cuda_root_dir
+            )
+        )
+    if CUDNN_ROOT_DIR is None and with_cudnn:
+        print(
+            "Cannot find the root directory of cudnn, please set {}=/path/to/your/cudnn_root_dir".format(
+                ev_cudnn_root_dir
+            )
+        )
+
+    build_dir = os.path.abspath(
+        _get_build_dir(name) if build_dir is None else build_dir
+    )
+    if not os.path.exists(build_dir):
+        os.makedirs(build_dir, exist_ok=True)
+
+    if verbose:
+        print("Using {} to build megengine custom op".format(build_dir))
+
+    # phase 2: version check
+    version, old_version = version_check(
+        name,
+        sources,
+        [extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
+        build_dir,
+        with_cuda,
+        with_cudnn,
+        abi_tag,
+    )
+    if verbose:
+        if version != old_version and old_version is not None:
+            print(
+                "Input conditions of custom op {} have changed, bumping to version {}".format(
+                    name, version
+                )
+            )
+        print("Building custom op {} with version {}".format(name, version))
+    if version == old_version:
+        if verbose:
+            print(
+                "No modifications detected for {}, skipping build step...".format(name)
+            )
+        return
+    name = "{}_v{}".format(name, version)
+
+    # phase 3: compiler and ninja check
+    _check_ninja_availability()
+    _check_compiler_compatibility()
+
+    # phase 4: setup the compile flags
+    user_includes, system_includes = _setup_includes(
+        extra_include_paths, with_cuda, with_cudnn
+    )
+    common_cflags = _setup_common_cflags(user_includes, system_includes)
+    cuda_cflags = (
+        _setup_cuda_cflags(common_cflags, extra_cuda_cflags) if with_cuda else None
+    )
+    ldflags = _setup_ldflags(extra_ldflags, with_cuda, with_cudnn)
+
+    if IS_WINDOWS:
+        cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags
+        cflags = _nt_quote_args(cflags)
+    else:
+        cflags = common_cflags + ["-fPIC", "-std=c++14"] + extra_cflags
+
+    ldflags = _add_shared_flag(ldflags)
+    if sys.platform.startswith("darwin"):
+        ldflags.append("-undefined dynamic_lookup")
+    elif IS_WINDOWS:
+        ldflags += ["/link"]
+        ldflags = _nt_quote_args(ldflags)
+
+    baton = FileBaton(os.path.join(build_dir, "lock"))
+    if baton.try_acquire():
+        try:
+            # phase 5: generate ninja build file
+            objs = [_obj_file_path(src) for src in sources]
+            name += ".dll" if IS_WINDOWS else ".so"
+
+            build_file_path = os.path.join(build_dir, "build.ninja")
+            if verbose:
+                print("Emitting ninja build file {}".format(build_file_path))
+            _dump_ninja_file(
+                path=build_file_path,
+                cflags=cflags,
+                post_cflags=None,
+                cuda_cflags=cuda_cflags,
+                cuda_post_cflags=None,
+                sources=sources,
+                objects=objs,
+                ldflags=ldflags,
+                library_target=name,
+                with_cuda=with_cuda,
+            )
+
+            # phase 6: build with ninja
+            if verbose:
+                print(
+                    "Compiling and linking your custom op {}".format(
+                        os.path.join(build_dir, name)
+                    )
+                )
+            _build_with_ninja(build_dir, verbose, "compiling error")
+        finally:
+            baton.release()
+    else:
+        baton.wait()
+
+    return os.path.join(build_dir, name)
+
+
+def build_and_load(
+    name: str,
+    sources: Union[str, List[str]],
+    extra_cflags: Union[str, List[str]] = [],
+    extra_cuda_cflags: Union[str, List[str]] = [],
+    extra_ldflags: Union[str, List[str]] = [],
+    extra_include_paths: Union[str, List[str]] = [],
+    with_cuda: Optional[bool] = None,
+    build_dir: Optional[str] = None,
+    verbose: bool = False,
+    abi_tag: Optional[int] = None,
+) -> str:
+    r"""Build and Load a Custom Op with ninja in a just-in-time (JIT) way.
+    Same as the function ``build()`` but also loads the built dynamic library.
+
+    Args:
+        same as ``build()``
+
+    Returns:
+        the compiled dynamic library path
+
+    """
+
+    lib_path = build(
+        name,
+        sources,
+        extra_cflags,
+        extra_cuda_cflags,
+        extra_ldflags,
+        extra_include_paths,
+        with_cuda,
+        build_dir,
+        verbose,
+        abi_tag,
+    )
+    if verbose:
+        print("Loading the compiled custom op {}".format(lib_path))
+    load(lib_path)
+    return lib_path
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index 30f61a0f..515f8ead 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -766,6 +766,13 @@ void init_custom(pybind11::module m) {
     m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
     m.def("_get_custom_op_list", &get_custom_op_list);
+    m.def("get_custom_op_abi_tag", [](void) -> int {
+        int ret = 0;
+#ifdef _GLIBCXX_USE_CXX11_ABI
+        ret = _GLIBCXX_USE_CXX11_ABI;
+#endif
+        return ret;
+    });
 
     static PyMethodDef method_def = {
 #ifdef METH_FASTCALL
diff --git a/imperative/python/test/unit/core/custom_opsrc/elem_add.cpp b/imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
new file mode 100644
index 00000000..d8f0299d
--- /dev/null
+++ b/imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
@@ -0,0 +1,140 @@
+/**
+ * \file imperative/python/test/unit/core/custom_opsrc/elem_add.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/custom/custom.h"
+
+CUSTOM_OP_REG_BEGIN(ElemAddSmooth)
+
+void forward_device_infer(
+        const std::vector<Device>& inputs, const Param& params,
+        std::vector<Device>& outputs) {
+    outputs[0] = inputs[0];
+}
+
+void forward_shape_infer(
+        const std::vector<Shape>& inputs, const Param& params,
+        std::vector<Shape>& outputs) {
+    outputs[0] = inputs[0];
+}
+
+void forward_dtype_infer(
+        const std::vector<DType>& inputs, const Param& params,
+        std::vector<DType>& outputs) {
+    outputs[0] = inputs[0];
+}
+
+void forward_format_infer(
+        const std::vector<Format>& inputs, const Param& params,
+        std::vector<Format>& outputs) {
+    outputs[0] = inputs[0];
+}
+
+template <typename scalar_t>
+void forward_kernel(
+        const scalar_t* input0, const scalar_t* input1, scalar_t* output, size_t len,
+        float smooth) {
+    for (size_t i = 0; i < len; ++i) {
+        output[i] = input0[i] + input1[i];
+        if (output[i] < 0)
+            output[i] += smooth;
+        else
+            output[i] -= smooth;
+    }
+}
+
+void forward_compute(
+        const std::vector<Tensor>& inputs, const Param& params,
+        std::vector<Tensor>& outputs) {
+    DISPATCH_SIGN_INT_AND_FLOAT_TYPES(
+            outputs[0].dtype(), "forward_compute", ([&]() {
+                forward_kernel<scalar_t>(
+                        inputs[0].data<scalar_t>(), inputs[1].data<scalar_t>(),
+                        outputs[0].data<scalar_t>(), outputs[0].size(),
+                        params["smooth"].as<float>());
+            }));
+}
+
+CUSTOM_OP_REG(ElemAddSmoothForward)
+        .set_description(
+                "Custom ElemAdd Operator With a Smooth Parameter, "
+                "which is used to verify the CPU kernel")
+        .add_input("lhs")
+        .add_input("rhs")
+        .add_output("output")
+        .add_param("smooth", 0.f)
+        .set_device_infer(forward_device_infer)
+        .set_shape_infer(forward_shape_infer)
+        .set_dtype_infer(forward_dtype_infer)
+        .set_format_infer(forward_format_infer)
+        .set_compute(forward_compute);
+
+void backward_device_infer(
+        const std::vector<Device>& ograds, const Param& params,
+        std::vector<Device>& igrads) {
+    igrads[0] = ograds[0];
+    igrads[1] = ograds[0];
+}
+
+void backward_shape_infer(
+        const std::vector<Shape>& ograds, const Param& params,
+        std::vector<Shape>& igrads) {
+    igrads[0] = ograds[0];
+    igrads[1] = ograds[0];
+}
+
+void backward_dtype_infer(
+        const std::vector<DType>& ograds, const Param& params,
+        std::vector<DType>& igrads) {
+    igrads[0] = ograds[0];
+    igrads[1] = ograds[0];
+}
+
+void backward_format_infer(
+        const std::vector<Format>& ograds, const Param& params,
+        std::vector<Format>& igrads) {
+    igrads[0] = ograds[0];
+    igrads[1] = ograds[0];
+}
+
+template <typename scalar_t>
+void backward_kernel(
+        const scalar_t* ograd, scalar_t* igrad0, scalar_t* igrad1, size_t len) {
+    for (size_t i = 0; i < len; ++i) {
+        igrad0[i] = ograd[i];
+        igrad1[i] = ograd[i];
+    }
+}
+
+void backward_compute(
+        const std::vector<Tensor>& ograds, const Param& params,
+        std::vector<Tensor>& igrads) {
+    DISPATCH_SIGN_INT_AND_FLOAT_TYPES(
+            igrads[0].dtype(), "backward_compute", ([&]() {
+                backward_kernel<scalar_t>(
+                        ograds[0].data<scalar_t>(), igrads[0].data<scalar_t>(),
+                        igrads[1].data<scalar_t>(), igrads[0].size());
+            }));
+}
+
+CUSTOM_OP_REG(ElemAddSmoothBackward)
+        .set_description(
+                "Custom ElemAdd Operator With a Smooth Parameter, "
+                "which is used to verify the CPU kernel")
+        .add_input("ograd")
+        .add_output("igrad_lhs")
+        .add_output("igrad_rhs")
+        .set_device_infer(backward_device_infer)
+        .set_shape_infer(backward_shape_infer)
+        .set_dtype_infer(backward_dtype_infer)
+        .set_format_infer(backward_format_infer)
+        .set_compute(backward_compute);
+
+CUSTOM_OP_REG_END(ElemAddSmooth)
diff --git a/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
new file mode 100644
index 00000000..31998dd9
--- /dev/null
+++ b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
@@ -0,0 +1,65 @@
+/**
+ * \file imperative/python/test/unit/core/custom_opsrc/matmul_scale.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./matmul_scale.h"
+#include "megbrain/custom/custom.h"
+
+CUSTOM_OP_REG_BEGIN(MatMulScale)
+
+void forward_shape_infer(
+        const std::vector<Shape>& inputs, const Param& params,
+        std::vector<Shape>& outputs) {
+    outputs[0] = {inputs[0][0], inputs[1][1]};
+}
+
+void forward_compute(
+        const std::vector<Tensor>& inputs, const Param& params,
+        std::vector<Tensor>& outputs) {
+    matmul_forward_helper(
+            inputs[0], inputs[1], outputs[0], inputs[0].shape()[0],
+            inputs[0].shape()[1], inputs[1].shape()[1], params["scale"].as<float>());
+}
+
+CUSTOM_OP_REG(MatMulScaleForward)
+        .add_inputs(2)
+        .add_outputs(1)
+        .add_param("scale", 1.0f)
+        .set_shape_infer(forward_shape_infer)
+        .set_compute("cuda", forward_compute);
+
+void backward_shape_infer(
+        const std::vector<Shape>& ograd_and_inputs, const Param& params,
+        std::vector<Shape>& outputs) {
+    outputs[0] = ograd_and_inputs[1];
+    outputs[1] = ograd_and_inputs[2];
+}
+
+void backward_compute(
+        const std::vector<Tensor>& ograd_and_inputs, const Param& params,
+        std::vector<Tensor>& igrads) {
+    matmul_backward_lhs_helper(
+            ograd_and_inputs[2], ograd_and_inputs[0], igrads[0],
+            ograd_and_inputs[1].shape()[0], ograd_and_inputs[1].shape()[1],
+            ograd_and_inputs[2].shape()[1], params["scale"].as<float>());
+    matmul_backward_rhs_helper(
+            ograd_and_inputs[1], ograd_and_inputs[0], igrads[1],
+            ograd_and_inputs[1].shape()[0], ograd_and_inputs[1].shape()[1],
+            ograd_and_inputs[2].shape()[1], params["scale"].as<float>());
+}
+
+CUSTOM_OP_REG(MatMulScaleBackward)
+        .add_inputs(3)
+        .add_outputs(2)
+        .add_param("scale", 1.0f)
+        .set_shape_infer(backward_shape_infer)
+        .set_compute("cuda", backward_compute);
+
+CUSTOM_OP_REG_END(MatMulScale)
diff --git a/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu
new file mode 100644
index 00000000..9d847d32
--- /dev/null
+++ b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu
@@ -0,0 +1,97 @@
+/**
+ * \file imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <stdio.h>
+#include "./matmul_scale.h"
+
+using namespace custom;
+
+// matmul_forward for Mat_mxk * Mat_kxn
+template <typename T>
+__global__ void matmul_forward_naive(
+        const T* lhs, const T* rhs, T* res, size_t M, size_t K, size_t N, float scale) {
+    int row = blockIdx.y * blockDim.y + threadIdx.y;
+    int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+    T acc = 0;
+    for (int i = 0; i < K; ++i)
+        acc += lhs[row * K + i] * rhs[i * N + col];
+    res[row * N + col] = acc * scale;
+}
+
+// matmul_backward_lhs for Mat_mxk * Mat_kxn = Mat_mxn
+// that is Mat_mxn * Mat_nxk
+template <typename T>
+__global__ void matmul_backward_lhs_naive(
+        const T* rhs, const T* ograd, T* lhs_grad, size_t M, size_t K, size_t N,
+        float scale) {
+    int row = blockIdx.y * blockDim.y + threadIdx.y;
+    int col = blockIdx.x * blockDim.x + threadIdx.x;
+    T acc = 0;
+    for (int i = 0; i < N; ++i)
+        acc += ograd[row * N + i] * rhs[col * N + i];
+    lhs_grad[row * K + col] = acc / scale;
+}
+
+// matmul_backward_rhs for Mat_mxk * Mat_kxn = Mat_mxn
+// that is Mat_kxm * Mat_mxn
+template <typename T>
+__global__ void matmul_backward_rhs_naive(
+        const T* lhs, const T* ograd, T* rhs_grad, size_t M, size_t K, size_t N,
+        float scale) {
+    int row = blockIdx.y * blockDim.y + threadIdx.y;
+    int col = blockIdx.x * blockDim.x + threadIdx.x;
+    T acc = 0;
+    for (int i = 0; i < M; ++i)
+        acc += lhs[i * K + row] * ograd[i * N + col];
+    rhs_grad[row * N + col] = acc / scale;
+}
+
+void matmul_forward_helper(
+        const Tensor& lhs, const Tensor& rhs, Tensor& res, size_t M, size_t K, size_t N,
+        float scale) {
+    dim3 block(1, 1);
+    dim3 grid(N / block.x, M / block.y);
+
+    DISPATCH_INT_AND_FLOAT_TYPES(res.dtype(), "matmul_forward", ([&]() {
+                                     matmul_forward_naive<scalar_t><<<grid, block>>>(
+                                             lhs.data<scalar_t>(), rhs.data<scalar_t>(),
+                                             res.data<scalar_t>(), M, K, N, scale);
+                                 }));
+}
+
+void matmul_backward_lhs_helper(
+        const Tensor& rhs, const Tensor& ograd, Tensor& lhs_grad, size_t M, size_t K,
+        size_t N, float scale) {
+    dim3 block(1, 1);
+    dim3 grid(K / block.x, M / block.y);
+    DISPATCH_INT_AND_FLOAT_TYPES(
+            lhs_grad.dtype(), "matmul_backward_lhs", ([&]() {
+                matmul_backward_lhs_naive<scalar_t><<<grid, block>>>(
+                        rhs.data<scalar_t>(), ograd.data<scalar_t>(),
+                        lhs_grad.data<scalar_t>(), M, K, N, scale);
+            }));
+}
+
+void matmul_backward_rhs_helper(
+        const Tensor& lhs, const Tensor& ograd, Tensor& rhs_grad, size_t M, size_t K,
+        size_t N, float scale) {
+    dim3 block(1, 1);
+    dim3 grid(N / block.x, K / block.y);
+    DISPATCH_INT_AND_FLOAT_TYPES(
+            rhs_grad.dtype(), "matmul_backward_rhs", ([&]() {
+                matmul_backward_rhs_naive<scalar_t><<<grid, block>>>(
+                        lhs.data<scalar_t>(), ograd.data<scalar_t>(),
+                        rhs_grad.data<scalar_t>(), M, K, N, scale);
+            }));
+}
diff --git a/imperative/python/test/unit/core/custom_opsrc/matmul_scale.h b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
new file mode 100644
index 00000000..5f7ea8d0
--- /dev/null
+++ b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
@@ -0,0 +1,24 @@
+/**
+ * \file imperative/python/test/unit/core/custom_opsrc/matmul_scale.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/custom/custom.h"
+
+using Tensor = custom::Tensor;
+
+void matmul_forward_helper(
+        const Tensor& lhs, const Tensor& rhs, Tensor& res, size_t M, size_t K, size_t N,
+        float scale);
+void matmul_backward_lhs_helper(
+        const Tensor& rhs, const Tensor& ograd, Tensor& lhs_grad, size_t M, size_t K,
+        size_t N, float scale);
+void matmul_backward_rhs_helper(
+        const Tensor& lhs, const Tensor& ograd, Tensor& rhs_grad, size_t M, size_t K,
+        size_t N, float scale);
diff --git a/imperative/python/test/unit/core/test_custom_op.py b/imperative/python/test/unit/core/test_custom_op.py
new file mode 100644
index 00000000..e2a9e4b2
--- /dev/null
+++ b/imperative/python/test/unit/core/test_custom_op.py
@@ -0,0 +1,111 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import os
+import platform
+import shutil
+import sys
+
+import numpy as np
+import pytest
+
+import megengine
+import megengine.functional as F
+import megengine.optimizer as optim
+from megengine import jit
+from megengine.autodiff import Function, GradManager
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core.ops import custom
+from megengine.device import get_device_count
+from megengine.module import Conv2d, Linear, Module
+from megengine.random import normal
+from megengine.tensor import Parameter, Tensor
+from megengine.utils import custom_op_tools
+
+
+def compare(ref, real):
+    if ref.shape != real.shape:
+        real = real.T
+    np.testing.assert_allclose(ref, real, rtol=1e-3, atol=1e-5)
+
+
+def build_and_clean(test_func):
+    def wrapper():
+        cur_dir_path = os.path.dirname(os.path.abspath(__file__))
+        build_path = os.path.join(cur_dir_path, "custom_opsrc", "build")
+        mgb_root_path = os.path.dirname(
+            os.path.dirname(
+                os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path)))
+            )
+        )
+        extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")]
+        extra_ld_flags = []
+
+        if sys.platform != "win32":
+            ld_path = os.environ.get("LD_LIBRARY_PATH")
+            if ld_path is not None:
+                ld_dirs = ld_path.split(":")
+                for ld_dir in ld_dirs:
+                    if os.path.exists(ld_dir) and os.path.isdir(ld_dir):
+                        for lib in os.listdir(ld_dir):
+                            if "megengine_shared" in lib:
+                                extra_ld_flags += [
+                                    "-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir)
+                                ]
+                                break
+
+        if get_device_count("gpu") > 0:
+            custom_opsrc = [
+                os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cpp"),
+                os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cu"),
+            ]
+        else:
+            custom_opsrc = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")]
+
+        lib_path = custom_op_tools.build_and_load(
+            "test_op",
+            custom_opsrc,
+            extra_include_paths=extra_include_paths,
+            extra_ldflags=extra_ld_flags,
+            build_dir=build_path,
+            verbose=False,
+            abi_tag=custom.get_custom_op_abi_tag(),
+        )
+        test_func()
+
+        custom.unload(lib_path)
+        if os.path.exists(build_path):
+            shutil.rmtree(build_path)
+
+    return wrapper
+
+
+@pytest.mark.skipif(
+    get_device_count("gpu") > 0, reason="elem_add operator is only supported on CPU"
+)
+@build_and_clean
+def test_custom_op_cpu_build():
+    assert "ElemAddSmoothForward" in custom._get_custom_op_list()
+    assert "ElemAddSmoothBackward" in custom._get_custom_op_list()
+    assert hasattr(custom, "ElemAddSmoothForward")
+    assert hasattr(custom, "ElemAddSmoothBackward")
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin",
+    reason="GPU kernel is only supported on Linux and Windows",
+)
+@pytest.mark.skipif(
+    get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU"
+)
+@build_and_clean
+def test_custom_op_gpu_build():
+    assert "MatMulScaleForward" in custom._get_custom_op_list()
+    assert "MatMulScaleBackward" in custom._get_custom_op_list()
+    assert hasattr(custom, "MatMulScaleForward")
+    assert hasattr(custom, "MatMulScaleBackward")
diff --git a/scripts/whl/macos/macos_build_whl.sh b/scripts/whl/macos/macos_build_whl.sh
index c411c67d..c20559a7 100755
--- a/scripts/whl/macos/macos_build_whl.sh
+++ b/scripts/whl/macos/macos_build_whl.sh
@@ -171,6 +171,7 @@ function do_build() {
 
         mkdir -p staging
         cp -a imperative/python/{megengine,setup.py,requires.txt,requires-style.txt,requires-test.txt} staging/
+        cp -a ${SRC_DIR}/src/custom/include staging/megengine/core/include/
         cd ${BUILD_DIR}/staging/megengine/core
         rt_file=`ls _imperative_rt.*.so`
         echo "rt file is: ${rt_file}"
diff --git a/scripts/whl/manylinux2014/do_build_common.sh b/scripts/whl/manylinux2014/do_build_common.sh
index 2df0dbff..0f149724 100755
--- a/scripts/whl/manylinux2014/do_build_common.sh
+++ b/scripts/whl/manylinux2014/do_build_common.sh
@@ -151,6 +151,7 @@ do
     rm -rf staging
     mkdir -p staging
     cp -a imperative/python/{megengine,setup.py,requires.txt,requires-style.txt,requires-test.txt} staging/
+    cp -a ${SRC_DIR}/src/custom/include/megbrain staging/megengine/core/include
    cd ${BUILD_DIR}/staging/megengine/core
 
    mkdir -p lib/ucx
diff --git a/scripts/whl/windows/windows_build_whl.sh b/scripts/whl/windows/windows_build_whl.sh
index b3824fbc..d33cb5c5 100755
--- a/scripts/whl/windows/windows_build_whl.sh
+++ b/scripts/whl/windows/windows_build_whl.sh
@@ -77,11 +77,13 @@ CUBLAS_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublas6
 CURAND_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/curand64_10.dll"
 CUBLASLT_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cublasLt64_10.dll"
 CUDART_LIB="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/bin/cudart64_101.dll"
-MGE_EXPORT_LIB="${SRC_DIR}/build_dir/host/build/src/megengine_shared.dll"
+MGE_EXPORT_DLL="${SRC_DIR}/build_dir/host/build/src/megengine_shared.dll"
+MGE_EXPORT_LIB="${SRC_DIR}/build_dir/host/build/src/megengine_shared.lib"
 
 function depend_real_copy() {
     REAL_DST=$1
     echo "real copy lib to $1"
+    cp "${MGE_EXPORT_DLL}" ${REAL_DST}
     cp "${MGE_EXPORT_LIB}" ${REAL_DST}
 
     if [ ${BUILD_WHL_CPU_ONLY} = "OFF" ]; then
@@ -190,6 +192,7 @@ function do_build() {
             rm -rf staging
             mkdir -p staging
             cp -a imperative/python/{megengine,setup.py,requires.txt,requires-style.txt,requires-test.txt} staging/
+            cp -a ${SRC_DIR}/src/custom/include/megbrain staging/megengine/core/include/
             cd ${BUILD_DIR}/staging/megengine/core
             rt_file=`ls _imperative_rt.*.pyd`
            echo "rt file is: ${rt_file}"
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 57664c74..30552579 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,3 +1,8 @@
+# Force define a SHARED target for the whl package: when building for APPLE,
+# BUILD_SHARED_LIBS is forced to OFF because Xcode needs it.
+set(MGE_SHARED_LIB megengine_shared)
+set(MGE_SHARED_LIB ${MGE_SHARED_LIB} PARENT_SCOPE)
+
 if(MGE_WITH_JIT_MLIR)
   add_subdirectory(jit/include/megbrain/jit/mlir/ir)
 endif()
@@ -206,32 +211,30 @@ set (_VER_FILE ${PROJECT_SOURCE_DIR}/src/version.ld)
 
 # Build as SHARED or STATIC depending on BUILD_SHARED_LIBS=ON/OFF
 add_library(megengine)
-# force define a SHARED target for whl, caused by when build for APPLE
-# we will force set BUILD_SHARED_LIBS=OFF for xcode needed
-add_library(megengine_shared SHARED)
+add_library(${MGE_SHARED_LIB} SHARED)
 target_link_libraries(megengine PRIVATE ${MGE_CUDA_LIBS})
 target_link_libraries(megengine PUBLIC megbrain megdnn)
-target_link_libraries(megengine_shared PUBLIC megbrain megdnn)
-target_link_libraries(megengine_shared PRIVATE ${MGE_CUDA_LIBS})
+target_link_libraries(${MGE_SHARED_LIB} PUBLIC megbrain megdnn)
+target_link_libraries(${MGE_SHARED_LIB} PRIVATE ${MGE_CUDA_LIBS})
 if (UNIX AND NOT APPLE)
     target_link_options(megengine PRIVATE -Wl,--no-undefined -Wl,--version-script=${_VER_FILE})
     set_target_properties(megengine PROPERTIES LINK_DEPENDS ${_VER_FILE})
-    target_link_options(megengine_shared PRIVATE -Wl,--no-undefined -Wl,--version-script=${_VER_FILE})
-    set_target_properties(megengine_shared PROPERTIES LINK_DEPENDS ${_VER_FILE})
+    target_link_options(${MGE_SHARED_LIB} PRIVATE -Wl,--no-undefined -Wl,--version-script=${_VER_FILE})
+    set_target_properties(${MGE_SHARED_LIB} PROPERTIES LINK_DEPENDS ${_VER_FILE})
 endif()
 if(WIN32 OR MSVC)
     target_compile_definitions(megbrain PRIVATE MGE_DLL_EXPORT)
     target_compile_definitions(megdnn PRIVATE MGE_DLL_EXPORT)
     target_compile_definitions(megengine PRIVATE MGE_DLL_EXPORT)
-    target_compile_definitions(megengine_shared PRIVATE MGE_DLL_EXPORT)
+    target_compile_definitions(${MGE_SHARED_LIB} PRIVATE MGE_DLL_EXPORT)
     # please do not use WINDOWS_EXPORT_ALL_SYMBOLS, as symbols max than 65535 when build with CUDA
     #set_target_properties(megengine PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
-    #set_target_properties(megengine_shared PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
+    #set_target_properties(${MGE_SHARED_LIB} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
 endif()
 if (MGE_WITH_DISTRIBUTED)
     message(VERBOSE "megengine configured to link megray")
     target_link_libraries(megengine PUBLIC megray)
-    target_link_libraries(megengine_shared PUBLIC megray)
+    target_link_libraries(${MGE_SHARED_LIB} PUBLIC megray)
 endif()
 
 # Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready
 # for this.
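Note: end to end, the Python API introduced by this patch is driven roughly as follows. This is a minimal usage sketch, not part of the patch itself; the source file "my_op.cpp" and the op name "MyOp" are hypothetical placeholders for whatever ops the compiled library registers.

    from megengine.core.ops import custom
    from megengine.utils import custom_op_tools

    # JIT-compile the sources into a dynamic library and load it; by default
    # the build lands under <user cache dir>/mge_custom_op/<name>.
    lib_path = custom_op_tools.build_and_load(
        "my_op",
        ["my_op.cpp"],  # hypothetical custom op source file
        verbose=True,
        abi_tag=custom.get_custom_op_abi_tag(),  # match MegEngine's C++ ABI
    )

    # every op registered by the library is now exposed on the custom module,
    # e.g. custom.MyOp(...) if the library registered an op named "MyOp"

    custom.unload(lib_path)  # uninstall the ops when done
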
diff --git a/src/custom/impl/manager.cpp b/src/custom/impl/manager.cpp index 3de0986f..39419d7e 100644 --- a/src/custom/impl/manager.cpp +++ b/src/custom/impl/manager.cpp @@ -18,12 +18,31 @@ #ifndef _WIN32 #include +#else +#include #endif using namespace mgb; namespace custom { +#ifdef _WIN32 +#define RTLD_LAZY 0 + +void* dlopen(const char* file, int) { + return static_cast(LoadLibrary(file)); +} + +int dlclose(void* handle) { + return static_cast(FreeLibrary(static_cast(handle))); +} + +const char* dlerror(void) { + static char win_err_info[] = "no dlerror info in windows"; + return win_err_info; +} +#endif + CustomOpManager* CustomOpManager::inst(void) { static CustomOpManager op_manager; return &op_manager; @@ -127,7 +146,6 @@ std::vector CustomOpManager::op_id_list(void) { return ret; } -#ifndef _WIN32 CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY) : m_handle(nullptr, [](void* handle) { dlclose(handle); }) { auto op_list_before_load = CustomOpManager::inst()->op_name_list(); @@ -146,12 +164,6 @@ CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY) } } } -#else -CustomLib::CustomLib(const std::string& path, int mode = 0) - : m_handle(nullptr, [](void* handle) {}) { - mgb_assert(false, "custom op is only supported on Linux now"); -} -#endif const std::vector& CustomLib::ops_in_lib(void) const { return m_ops; diff --git a/src/custom/include/megbrain/custom/custom.h b/src/custom/include/megbrain/custom/custom.h index e6751f25..726076a7 100644 --- a/src/custom/include/megbrain/custom/custom.h +++ b/src/custom/include/megbrain/custom/custom.h @@ -16,7 +16,8 @@ #include "tensor.h" namespace custom { -std::shared_ptr op_insert(std::string opname, uint32_t version); +MGE_WIN_DECLSPEC_FUC std::shared_ptr op_insert( + std::string opname, uint32_t version); } #define CUSTOM_OP_REG(OpName) \ diff --git a/src/custom/include/megbrain/custom/op.h b/src/custom/include/megbrain/custom/op.h index 2646ce56..b1afc801 100644 --- a/src/custom/include/megbrain/custom/op.h +++ b/src/custom/include/megbrain/custom/op.h @@ -32,27 +32,26 @@ namespace custom { using RunTimeId = uint64_t; -class ArgInfo { +class MGE_WIN_DECLSPEC_FUC ArgInfo { CUSTOM_PIMPL_CLS_DECL(ArgInfo); - MGE_WIN_DECLSPEC_FUC ArgInfo( - const std::string& name, const std::string& desc, + ArgInfo(const std::string& name, const std::string& desc, const std::unordered_set& dtypes, const int& ndim, const std::string& mem_stgy); - MGE_WIN_DECLSPEC_FUC const std::string& name(void) const; - MGE_WIN_DECLSPEC_FUC const std::string& desc(void) const; - MGE_WIN_DECLSPEC_FUC const std::unordered_set& dtypes(void) const; - MGE_WIN_DECLSPEC_FUC int ndim(void) const; - MGE_WIN_DECLSPEC_FUC const std::string& mem_strategy(void) const; + const std::string& name(void) const; + const std::string& desc(void) const; + const std::unordered_set& dtypes(void) const; + int ndim(void) const; + const std::string& mem_strategy(void) const; - MGE_WIN_DECLSPEC_FUC std::string str() const; + std::string str() const; }; -class CustomOp { +class MGE_WIN_DECLSPEC_FUC CustomOp { std::unique_ptr m_impl; public: - MGE_WIN_DECLSPEC_FUC CustomOp(const std::string& op_type, uint32_t version); + CustomOp(const std::string& op_type, uint32_t version); PREVENT_COPY_AND_ASSIGN(CustomOp); using DeviceInferFuncPtr = @@ -71,70 +70,65 @@ public: void (*)(const std::vector&, const Param&, std::vector&); // write for forward - MGE_WIN_DECLSPEC_FUC CustomOp& set_device_infer(DeviceInferFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& 
set_shape_infer(ShapeInferFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_dtype_infer(DTypeInferFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_format_infer(FormatInferFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_preprocess(PreprocessFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_preprocess( - const std::string& device, PreprocessFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_postprocess(PostprocessFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_postprocess( - const std::string& device, PostprocessFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_compute(ComputeFuncPtr func); - MGE_WIN_DECLSPEC_FUC CustomOp& set_compute( - const std::string& device, ComputeFuncPtr func); - - MGE_WIN_DECLSPEC_FUC CustomOp& set_description(const std::string& op_desc); - MGE_WIN_DECLSPEC_FUC CustomOp& add_input( + CustomOp& set_device_infer(DeviceInferFuncPtr func); + CustomOp& set_shape_infer(ShapeInferFuncPtr func); + CustomOp& set_dtype_infer(DTypeInferFuncPtr func); + CustomOp& set_format_infer(FormatInferFuncPtr func); + CustomOp& set_preprocess(PreprocessFuncPtr func); + CustomOp& set_preprocess(const std::string& device, PreprocessFuncPtr func); + CustomOp& set_postprocess(PostprocessFuncPtr func); + CustomOp& set_postprocess(const std::string& device, PostprocessFuncPtr func); + CustomOp& set_compute(ComputeFuncPtr func); + CustomOp& set_compute(const std::string& device, ComputeFuncPtr func); + + CustomOp& set_description(const std::string& op_desc); + CustomOp& add_input( const std::string& name, const std::string& desc, const std::initializer_list& legal_dtypes = {"float32"}, int dims = -1, const std::string& mem_stgy = "default"); - MGE_WIN_DECLSPEC_FUC CustomOp& add_output( + CustomOp& add_output( const std::string& name, const std::string& desc, const std::initializer_list& legal_dtypes = {"float32"}, int dims = -1, const std::string& mem_stgy = "default"); - MGE_WIN_DECLSPEC_FUC CustomOp& add_input( + CustomOp& add_input( const std::string& name, const std::initializer_list& legal_dtypes = {"float32"}, int dims = -1, const std::string& mem_stgy = "default"); - MGE_WIN_DECLSPEC_FUC CustomOp& add_output( + CustomOp& add_output( const std::string& name, const std::initializer_list& legal_dtypes = {"float32"}, int dims = -1, const std::string& mem_stgy = "default"); - MGE_WIN_DECLSPEC_FUC CustomOp& add_inputs(const size_t& input_num); - MGE_WIN_DECLSPEC_FUC CustomOp& add_outputs(const size_t& output_num); - MGE_WIN_DECLSPEC_FUC CustomOp& add_param( - const std::string& name, const ParamVal& default_val); - MGE_WIN_DECLSPEC_FUC CustomOp& add_param( + CustomOp& add_inputs(const size_t& input_num); + CustomOp& add_outputs(const size_t& output_num); + CustomOp& add_param(const std::string& name, const ParamVal& default_val); + CustomOp& add_param( const std::string& name, const std::string& desc, const ParamVal& default_val); // read - MGE_WIN_DECLSPEC_FUC std::string op_type(void) const; - MGE_WIN_DECLSPEC_FUC std::string op_desc(void) const; - MGE_WIN_DECLSPEC_FUC RunTimeId runtime_id(void) const; - MGE_WIN_DECLSPEC_FUC size_t input_num(void) const; - MGE_WIN_DECLSPEC_FUC size_t output_num(void) const; - MGE_WIN_DECLSPEC_FUC std::string str(void) const; - - MGE_WIN_DECLSPEC_FUC const ParamInfo& param_info(void) const; - MGE_WIN_DECLSPEC_FUC ArgInfo input_info(size_t idx) const; - MGE_WIN_DECLSPEC_FUC ArgInfo output_info(size_t idx) const; - MGE_WIN_DECLSPEC_FUC const std::vector& inputs_info(void) const; - MGE_WIN_DECLSPEC_FUC const std::vector& outputs_info(void) 
diff --git a/src/custom/include/megbrain/custom/param.h b/src/custom/include/megbrain/custom/param.h
index d895d913..f90a2674 100644
--- a/src/custom/include/megbrain/custom/param.h
+++ b/src/custom/include/megbrain/custom/param.h
@@ -23,7 +23,7 @@ class ParamInfoImpl;
 class ParamImpl;
 
 // Schema of a param element
-class ParamSchema {
+class MGE_WIN_DECLSPEC_FUC ParamSchema {
     CUSTOM_PIMPL_CLS_DECL(ParamSchema);
     ParamSchema(
             const std::string& name, const ParamVal& value,
@@ -36,7 +36,7 @@ class ParamSchema {
     std::string str(void) const;
 };
 
-class ParamInfo {
+class MGE_WIN_DECLSPEC_FUC ParamInfo {
     CUSTOM_PIMPL_CLS_DECL(ParamInfo);
 
     void set_tag(const std::string&);
@@ -46,16 +46,16 @@ class ParamInfo {
     const std::vector<ParamSchema>& meta(void) const;
 };
 
-class Param {
+class MGE_WIN_DECLSPEC_FUC Param {
     CUSTOM_PIMPL_CLS_DECL(Param);
 
-    MGE_WIN_DECLSPEC_FUC Param(const ParamInfo&);
-    MGE_WIN_DECLSPEC_FUC ParamVal& operator[](const std::string&);
-    MGE_WIN_DECLSPEC_FUC const ParamVal& operator[](const std::string&) const;
-    MGE_WIN_DECLSPEC_FUC const std::unordered_map<std::string, ParamVal>& raw() const;
-    MGE_WIN_DECLSPEC_FUC bool exist(const std::string& name) const;
-    MGE_WIN_DECLSPEC_FUC std::string to_bytes(void) const;
-    MGE_WIN_DECLSPEC_FUC void from_bytes(const std::string&);
+    Param(const ParamInfo&);
+    ParamVal& operator[](const std::string&);
+    const ParamVal& operator[](const std::string&) const;
+    const std::unordered_map<std::string, ParamVal>& raw() const;
+    bool exist(const std::string& name) const;
+    std::string to_bytes(void) const;
+    void from_bytes(const std::string&);
 };
 
 MGE_WIN_DECLSPEC_FUC bool operator==(const Param&, const Param&);
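Note: param.h follows the same class-scope annotation pattern; the freestanding `operator==` keeps its explicit `MGE_WIN_DECLSPEC_FUC` because it is not a class member. A sketch of the Param surface (it assumes `param` arrives from a custom-op callback and that a "factor" entry exists, both purely for illustration):

```cpp
// Sketch only: Param access and its bytes round-trip.
#include <cassert>
#include <string>
#include "megbrain/custom/param.h"

void param_demo(custom::Param& param) {
    if (param.exist("factor")) {
        param["factor"] = 3.0f;  // operator[] yields a mutable ParamVal
    }
    std::string bytes = param.to_bytes();  // serialize all values
    custom::Param copy(param);             // copy ctor from CUSTOM_PIMPL_CLS_DECL
    copy.from_bytes(bytes);                // restore values in place
    assert(copy == param);                 // freestanding exported operator==
}
```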
diff --git a/src/custom/include/megbrain/custom/param_val.h b/src/custom/include/megbrain/custom/param_val.h
index 31b2a4b6..d7f3b521 100644
--- a/src/custom/include/megbrain/custom/param_val.h
+++ b/src/custom/include/megbrain/custom/param_val.h
@@ -169,21 +169,21 @@ std::string vec2str(const std::vector<T>& vec) {
  * Con1: user need to set the type explicitly when class template instantiation
  * Con2: ParamVal<T1> can not be assigned to ParamVal<T2>
  */
-class ParamVal {
+class MGE_WIN_DECLSPEC_FUC ParamVal {
     std::unique_ptr<void, void_deleter> m_ptr;
     ParamDynType m_type;
 
 public:
     template <typename T>
-    MGE_WIN_DECLSPEC_FUC ParamVal(const T& val);
+    ParamVal(const T& val);
     template <typename T>
-    MGE_WIN_DECLSPEC_FUC ParamVal(const std::initializer_list<T>& val);
+    ParamVal(const std::initializer_list<T>& val);
 
-    MGE_WIN_DECLSPEC_FUC ParamVal();
-    MGE_WIN_DECLSPEC_FUC ParamVal(const char* str);
-    MGE_WIN_DECLSPEC_FUC ParamVal(const std::initializer_list<std::string>& strs);
-    MGE_WIN_DECLSPEC_FUC ParamVal(const std::vector<std::string>& strs);
-    MGE_WIN_DECLSPEC_FUC ParamVal(const ParamVal& rhs);
+    ParamVal();
+    ParamVal(const char* str);
+    ParamVal(const std::initializer_list<std::string>& strs);
+    ParamVal(const std::vector<std::string>& strs);
+    ParamVal(const ParamVal& rhs);
 
     template <typename T>
     ParamVal& operator=(const T& rhs);
@@ -196,30 +196,39 @@ public:
     ParamVal& operator=(const ParamVal& rhs);
 
     template <typename T>
-    MGE_WIN_DECLSPEC_FUC const T& as(void) const;
+    const T& as(void) const;
     template <typename T>
-    MGE_WIN_DECLSPEC_FUC T& as(void);
-
-    MGE_WIN_DECLSPEC_FUC const void* raw_ptr(void) const;
-    MGE_WIN_DECLSPEC_FUC void* raw_ptr(void);
-    MGE_WIN_DECLSPEC_FUC ParamDynType type(void) const;
-    MGE_WIN_DECLSPEC_FUC std::string str(void) const;
-    MGE_WIN_DECLSPEC_FUC size_t size(void) const;
-
-    MGE_WIN_DECLSPEC_FUC static std::string to_bytes(const ParamVal& value);
-    MGE_WIN_DECLSPEC_FUC static ParamVal from_bytes(
-            const std::string& bytes, size_t& offset);
-
-    friend ParamVal operator+(const ParamVal& lhs, const ParamVal& rhs);
-    friend ParamVal operator-(const ParamVal& lhs, const ParamVal& rhs);
-    friend ParamVal operator*(const ParamVal& lhs, const ParamVal& rhs);
-    friend ParamVal operator/(const ParamVal& lhs, const ParamVal& rhs);
-    friend bool operator==(const ParamVal& lhs, const ParamVal& rhs);
-    friend bool operator!=(const ParamVal& lhs, const ParamVal& rhs);
-    friend bool operator>(const ParamVal& lhs, const ParamVal& rhs);
-    friend bool operator<(const ParamVal& lhs, const ParamVal& rhs);
-    friend bool operator>=(const ParamVal& lhs, const ParamVal& rhs);
-    friend bool operator<=(const ParamVal& lhs, const ParamVal& rhs);
+    T& as(void);
+
+    const void* raw_ptr(void) const;
+    void* raw_ptr(void);
+    ParamDynType type(void) const;
+    std::string str(void) const;
+    size_t size(void) const;
+
+    static std::string to_bytes(const ParamVal& value);
+    static ParamVal from_bytes(const std::string& bytes, size_t& offset);
+
+    MGE_WIN_DECLSPEC_FUC friend ParamVal operator+(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend ParamVal operator-(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend ParamVal operator*(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend ParamVal operator/(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator!=(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator>(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator<(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator>=(
+            const ParamVal& lhs, const ParamVal& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator<=(
+            const ParamVal& lhs, const ParamVal& rhs);
 };
 
 ParamVal operator+(const ParamVal& lhs, const ParamVal& rhs);
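Note: ParamVal owns a type-erased value plus a runtime type tag (`m_ptr`/`m_type`); the class-scope annotation exports the non-template members, while the arithmetic and comparison friends carry an explicit `MGE_WIN_DECLSPEC_FUC` since they are free functions. A sketch of the value semantics (it assumes integers and strings are among the supported dynamic types, as the constructors suggest):

```cpp
// Sketch only: ParamVal's dynamically typed value semantics.
#include <cassert>
#include <cstdint>
#include "megbrain/custom/param_val.h"

void paramval_demo() {
    custom::ParamVal a = 2;        // template ctor captures value + type tag
    custom::ParamVal b = 3;
    custom::ParamVal sum = a + b;  // exported friend operator+
    assert(sum.as<int32_t>() == 5);

    custom::ParamVal fmt = "nchw";  // const char* ctor stores a string
    assert(fmt.type() != a.type()); // type() exposes the runtime tag

    size_t offset = 0;              // from_bytes advances this cursor
    custom::ParamVal back =
            custom::ParamVal::from_bytes(custom::ParamVal::to_bytes(sum), offset);
    assert(back == sum);
}
```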
diff --git a/src/custom/include/megbrain/custom/tensor.h b/src/custom/include/megbrain/custom/tensor.h
index a1dd9ba5..53c54dc8 100644
--- a/src/custom/include/megbrain/custom/tensor.h
+++ b/src/custom/include/megbrain/custom/tensor.h
@@ -30,9 +30,9 @@ namespace custom {
 #define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) \
     custom_type,
 
-class Device {
-    MGE_WIN_DECLSPEC_FUC const void* impl() const;
-    MGE_WIN_DECLSPEC_FUC Device(const void* impl);
+class MGE_WIN_DECLSPEC_FUC Device {
+    const void* impl() const;
+    Device(const void* impl);
     CUSTOM_PIMPL_CLS_DECL(Device);
 
 public:
@@ -40,19 +40,19 @@ public:
     CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL)
     };
 
-    MGE_WIN_DECLSPEC_FUC Device(const std::string& device);
-    MGE_WIN_DECLSPEC_FUC Device(const char* device);
-    MGE_WIN_DECLSPEC_FUC Device(DeviceEnum device);
+    Device(const std::string& device);
+    Device(const char* device);
+    Device(DeviceEnum device);
 
-    MGE_WIN_DECLSPEC_FUC std::string str(void) const;
-    MGE_WIN_DECLSPEC_FUC DeviceEnum enumv(void) const;
+    std::string str(void) const;
+    DeviceEnum enumv(void) const;
 
-    MGE_WIN_DECLSPEC_FUC static bool is_legal(const std::string& device);
-    MGE_WIN_DECLSPEC_FUC static bool is_legal(DeviceEnum device);
-    MGE_WIN_DECLSPEC_FUC static std::vector<std::string> legal_devices(void);
+    static bool is_legal(const std::string& device);
+    static bool is_legal(DeviceEnum device);
+    static std::vector<std::string> legal_devices(void);
 
     friend class Tensor;
-    friend bool operator==(const Device& lhs, const Device& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(const Device& lhs, const Device& rhs);
     CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
 };
 
@@ -60,23 +60,23 @@ using DeviceEnum = Device::DeviceEnum;
 
 bool operator==(const Device& lhs, const Device& rhs);
 
-class Shape {
-    MGE_WIN_DECLSPEC_FUC const void* impl() const;
-    MGE_WIN_DECLSPEC_FUC Shape(const void* impl);
+class MGE_WIN_DECLSPEC_FUC Shape {
+    const void* impl() const;
+    Shape(const void* impl);
     CUSTOM_PIMPL_CLS_DECL(Shape);
 
 public:
-    MGE_WIN_DECLSPEC_FUC Shape(const std::vector<size_t>& rhs);
-    MGE_WIN_DECLSPEC_FUC Shape(const std::initializer_list<size_t>& rhs);
+    Shape(const std::vector<size_t>& rhs);
+    Shape(const std::initializer_list<size_t>& rhs);
 
     size_t& operator[](size_t idx);
     size_t operator[](size_t idx) const;
 
-    MGE_WIN_DECLSPEC_FUC void ndim(size_t dim);
-    MGE_WIN_DECLSPEC_FUC size_t ndim(void) const;
+    void ndim(size_t dim);
+    size_t ndim(void) const;
 
     friend class Tensor;
-    friend bool operator==(const Shape& lhs, const Shape& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(const Shape& lhs, const Shape& rhs);
     CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
 };
 
@@ -104,9 +104,9 @@ using bfloat16_t = uint16_t;
 
 #define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type,
 
-class DType {
-    MGE_WIN_DECLSPEC_FUC const void* impl() const;
-    MGE_WIN_DECLSPEC_FUC DType(const void* impl);
+class MGE_WIN_DECLSPEC_FUC DType {
+    const void* impl() const;
+    DType(const void* impl);
     CUSTOM_PIMPL_CLS_DECL(DType);
 
 public:
@@ -114,27 +114,33 @@ public:
     CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL)
     };
 
-    MGE_WIN_DECLSPEC_FUC DType(const std::string& dtype);
-    MGE_WIN_DECLSPEC_FUC DType(const char* dtype);
-    MGE_WIN_DECLSPEC_FUC DType(
-            const std::string& dtype, float scale, uint8_t zero_point = 0);
-    MGE_WIN_DECLSPEC_FUC DType(const char* dtype, float scale, uint8_t zero_point = 0);
-    MGE_WIN_DECLSPEC_FUC DType(DTypeEnum dtype);
-    MGE_WIN_DECLSPEC_FUC DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0);
-
-    MGE_WIN_DECLSPEC_FUC std::string str(void) const;
-    MGE_WIN_DECLSPEC_FUC DTypeEnum enumv() const;
-    MGE_WIN_DECLSPEC_FUC float scale(void) const;
-    MGE_WIN_DECLSPEC_FUC uint8_t zero_point(void) const;
+    DType(const std::string& dtype);
+    DType(const char* dtype);
+    DType(const std::string& dtype, float scale, uint8_t zero_point = 0);
+    DType(const char* dtype, float scale, uint8_t zero_point = 0);
+    DType(DTypeEnum dtype);
+    DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0);
+
+    std::string str(void) const;
+    DTypeEnum enumv() const;
+    float scale(void) const;
+    uint8_t zero_point(void) const;
 
     template <typename T>
-    MGE_WIN_DECLSPEC_FUC bool is_compatible(void) const;
+    bool is_compatible(void) const;
 
-    MGE_WIN_DECLSPEC_FUC static bool is_legal(const std::string& dtype);
-    MGE_WIN_DECLSPEC_FUC static bool is_legal(const DTypeEnum& dtype);
-    MGE_WIN_DECLSPEC_FUC static std::vector<std::string> legal_dtypes(void);
+    static bool is_legal(const std::string& dtype);
+    static bool is_legal(const DTypeEnum& dtype);
+    static std::vector<std::string> legal_dtypes(void);
 
     friend class Tensor;
-    friend bool operator==(const DType& lhs, const DType& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(const DType& lhs, const DType& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(
+            const DType& lhs, const std::string& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(const DType& lhs, const char* rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(
+            const std::string& lhs, const DType& rhs);
+    MGE_WIN_DECLSPEC_FUC friend bool operator==(const char* lhs, const DType& rhs);
+
     CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
 };
@@ -180,45 +186,45 @@ bool operator==(const DType& lhs, const char* rhs);
 bool operator==(const std::string& lhs, const DType& rhs);
 bool operator==(const char* lhs, const DType& rhs);
 
-class Format {
-    MGE_WIN_DECLSPEC_FUC const void* impl() const;
-    MGE_WIN_DECLSPEC_FUC Format(const void* impl);
+class MGE_WIN_DECLSPEC_FUC Format {
+    const void* impl() const;
+    Format(const void* impl);
     CUSTOM_PIMPL_CLS_DECL(Format);
 
 public:
-    MGE_WIN_DECLSPEC_FUC Format(const std::string& format);
-    MGE_WIN_DECLSPEC_FUC Format(const char* format);
+    Format(const std::string& format);
+    Format(const char* format);
 
-    MGE_WIN_DECLSPEC_FUC std::string str(void) const;
-    MGE_WIN_DECLSPEC_FUC bool is_default(void) const;
+    std::string str(void) const;
+    bool is_default(void) const;
 
     friend class Tensor;
     CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
 };
 
-class Tensor {
+class MGE_WIN_DECLSPEC_FUC Tensor {
     void* m_tensor;
 
-    MGE_WIN_DECLSPEC_FUC const void* impl(void) const;
-    MGE_WIN_DECLSPEC_FUC Tensor(const void* impl);
+    const void* impl(void) const;
+    Tensor(const void* impl);
 
-    MGE_WIN_DECLSPEC_FUC const size_t* shapes_raw(void) const;
-    MGE_WIN_DECLSPEC_FUC const ptrdiff_t* strides_raw(void) const;
+    const size_t* shapes_raw(void) const;
+    const ptrdiff_t* strides_raw(void) const;
 
 public:
     Tensor() = delete;
-    MGE_WIN_DECLSPEC_FUC Tensor(const Tensor& rhs);
-    MGE_WIN_DECLSPEC_FUC Tensor& operator=(const Tensor& rhs);
-
-    MGE_WIN_DECLSPEC_FUC Shape shape(void) const;
-    MGE_WIN_DECLSPEC_FUC DType dtype(void) const;
-    MGE_WIN_DECLSPEC_FUC Format format(void) const;
-    MGE_WIN_DECLSPEC_FUC Device device(void) const;
-
-    MGE_WIN_DECLSPEC_FUC size_t size(void) const;
-    MGE_WIN_DECLSPEC_FUC std::vector<ptrdiff_t> stride(void) const;
-    MGE_WIN_DECLSPEC_FUC float scale(void) const;
-    MGE_WIN_DECLSPEC_FUC uint8_t zero_point(void) const;
+    Tensor(const Tensor& rhs);
+    Tensor& operator=(const Tensor& rhs);
+
+    Shape shape(void) const;
+    DType dtype(void) const;
+    Format format(void) const;
+    Device device(void) const;
+
+    size_t size(void) const;
+    std::vector<ptrdiff_t> stride(void) const;
+    float scale(void) const;
+    uint8_t zero_point(void) const;
 
     void* data(void);
     const void* data(void) const;
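Note: tensor.h applies the same pattern to Device, Shape, DType, Format, and Tensor; the one functional addition is the set of exported `operator==` friends that let a DType be compared directly against a dtype name. A sketch of the read-only Tensor surface (it assumes a contiguous, "default"-format float32 tensor, purely for illustration):

```cpp
// Sketch only: inspecting tensor metadata inside a compute callback.
#include <cstddef>
#include <cstdio>
#include <vector>
#include "megbrain/custom/tensor.h"

void dump_meta(const custom::Tensor& t) {
    custom::Shape shp = t.shape();
    std::vector<ptrdiff_t> strides = t.stride();  // one entry per dim
    for (size_t i = 0; i < shp.ndim(); ++i) {
        std::printf("dim %zu: extent=%zu stride=%td\n", i, shp[i], strides[i]);
    }
    // The new friend overloads allow comparing a DType with a literal name.
    if (t.format().is_default() && t.dtype() == "float32") {
        const float* p = static_cast<const float*>(t.data());
        std::printf("first element: %f\n", static_cast<double>(p[0]));
    }
}
```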
diff --git a/src/custom/include/megbrain/custom/utils.h b/src/custom/include/megbrain/custom/utils.h
index 318bc62d..1bc64c6a 100644
--- a/src/custom/include/megbrain/custom/utils.h
+++ b/src/custom/include/megbrain/custom/utils.h
@@ -19,10 +19,19 @@
 
 namespace custom {
 
-void assert_failed_log(
+#ifndef MGE_WIN_DECLSPEC_FUC
+#ifdef _WIN32
+#define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
+#else
+#define MGE_WIN_DECLSPEC_FUC
+#endif
+#endif
+
+MGE_WIN_DECLSPEC_FUC void assert_failed_log(
         const char* file, int line, const char* func, const char* expr,
         const char* msg_fmt, ...);
 
+#ifndef _WIN32
 #define custom_expect(expr, msg...)                                               \
     if (!(expr)) {                                                                \
         assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
     }
@@ -33,8 +42,22 @@ void assert_failed_log(
         assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
     }                                                                             \
     assert((expr))
+#else
+#define custom_expect(expr, ...)                                                  \
+    if (!(expr)) {                                                                \
+        assert_failed_log(                                                        \
+                __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__);     \
+    }
+
+#define custom_assert(expr, ...)                                                  \
+    if (!(expr)) {                                                                \
+        assert_failed_log(                                                        \
+                __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__);     \
+    }                                                                             \
+    assert((expr))
+#endif
 
-class UnImpleWarnLog {
+class MGE_WIN_DECLSPEC_FUC UnImpleWarnLog {
 public:
     UnImpleWarnLog(
             const std::string& func, const std::string& attr, const std::string& val);
@@ -54,9 +77,9 @@ void impl_deleter(void* ptr) {
     std::unique_ptr<void, void_deleter> m_impl;                                   \
                                                                                   \
 public:                                                                           \
-    MGE_WIN_DECLSPEC_FUC Cls();                                                   \
-    MGE_WIN_DECLSPEC_FUC Cls(const Cls& rhs);                                     \
-    MGE_WIN_DECLSPEC_FUC Cls& operator=(const Cls& rhs)
+    Cls();                                                                        \
+    Cls(const Cls& rhs);                                                          \
+    Cls& operator=(const Cls& rhs)
 
 #define CUSTOM_PIMPL_CLS_DEFINE(Cls)                                              \
     Cls::Cls() : m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {}              \
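Note: utils.h now guarantees `MGE_WIN_DECLSPEC_FUC` is always defined: when the build system has not provided one, it defaults to `__declspec(dllexport)` on Windows and to nothing elsewhere, so out-of-tree custom-op sources can compile these headers directly. The `#ifndef _WIN32` split around the assertion macros exists because the named variadic parameter (`msg...`) and the comma-eating `##msg` are GNU extensions that MSVC does not accept; the Windows branch uses standard `__VA_ARGS__` instead. For contrast, the conventional Windows pattern toggles between export and import with a build flag, whereas this fallback hard-codes the export side (the sketch below uses a hypothetical `BUILDING_MY_LIB` define and is not from the patch):

```cpp
// Generic illustration of the usual export/import toggle on Windows.
#if defined(_WIN32)
#if defined(BUILDING_MY_LIB)          // hypothetical build-system define
#define MY_API __declspec(dllexport)  // building the DLL itself
#else
#define MY_API __declspec(dllimport)  // consuming the DLL
#endif
#else
#define MY_API                        // no-op on non-Windows toolchains
#endif
```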