@@ -831,3 +831,8 @@ if(MSVC OR WIN32) | |||||
endif() | endif() | ||||
endforeach() | endforeach() | ||||
endif() | endif() | ||||
if(MGE_WITH_JIT_MLIR) | |||||
add_subdirectory(tools/mlir/mgb-opt) | |||||
add_subdirectory(tools/mlir/mgb-file-check) | |||||
endif() |
@@ -297,7 +297,7 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) { | |||||
#if MGB_JIT_MLIR | #if MGB_JIT_MLIR | ||||
//! FIXME mlir does't support broadcast currently. | //! FIXME mlir does't support broadcast currently. | ||||
auto backend = MGB_GETENV("MGB_JIT_BACKEND"); | auto backend = MGB_GETENV("MGB_JIT_BACKEND"); | ||||
if (!strcmp(backend, "MLIR")) { | |||||
if (backend && !strcmp(backend, "MLIR")) { | |||||
for (VarNode* var : opr->input()) { | for (VarNode* var : opr->input()) { | ||||
if (!SymbolVar{var}.as_immutable_scalar().valid()) { | if (!SymbolVar{var}.as_immutable_scalar().valid()) { | ||||
if (opr->node_prop().dep_map().at(var) & | if (opr->node_prop().dep_map().at(var) & | ||||
@@ -44,6 +44,7 @@ | |||||
using namespace mlir; | using namespace mlir; | ||||
namespace { | |||||
template <typename OpTy> | template <typename OpTy> | ||||
static void createForAllDimensions(OpBuilder& builder, Location loc, | static void createForAllDimensions(OpBuilder& builder, Location loc, | ||||
SmallVectorImpl<Value>& values) { | SmallVectorImpl<Value>& values) { | ||||
@@ -80,7 +81,7 @@ static bool isSinkingBeneficiary(Operation* op) { | |||||
return isa<ConstantOp, DimOp>(op); | return isa<ConstantOp, DimOp>(op); | ||||
} | } | ||||
LogicalResult mlir::sinkOperationsIntoLaunchOp(gpu::LaunchOp launchOp) { | |||||
LogicalResult sink_operations_into_launch_op(gpu::LaunchOp launchOp) { | |||||
Region& launchOpBody = launchOp.body(); | Region& launchOpBody = launchOp.body(); | ||||
// Identify uses from values defined outside of the scope of the launch | // Identify uses from values defined outside of the scope of the launch | ||||
@@ -232,7 +233,6 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp, | |||||
launchOp.erase(); | launchOp.erase(); | ||||
} | } | ||||
namespace { | |||||
/// Pass that moves the kernel of each LaunchOp into its separate nested module. | /// Pass that moves the kernel of each LaunchOp into its separate nested module. | ||||
/// | /// | ||||
/// This pass moves the kernel code of each LaunchOp into a function created | /// This pass moves the kernel code of each LaunchOp into a function created | ||||
@@ -258,7 +258,7 @@ public: | |||||
.str(); | .str(); | ||||
// Pull in instructions that can be sunk | // Pull in instructions that can be sunk | ||||
if (failed(sinkOperationsIntoLaunchOp(op))) | |||||
if (failed(sink_operations_into_launch_op(op))) | |||||
return WalkResult::interrupt(); | return WalkResult::interrupt(); | ||||
gpu::GPUFuncOp outlinedFunc = | gpu::GPUFuncOp outlinedFunc = | ||||
outlineKernelFuncImpl(op, kernelFnName, operands); | outlineKernelFuncImpl(op, kernelFnName, operands); | ||||
@@ -327,7 +327,6 @@ private: | |||||
return kernelModule; | return kernelModule; | ||||
} | } | ||||
}; | }; | ||||
} // namespace | } // namespace | ||||
std::unique_ptr<mlir::Pass> mgb::jit::create_gpu_kernel_outlining_pass() { | std::unique_ptr<mlir::Pass> mgb::jit::create_gpu_kernel_outlining_pass() { | ||||
@@ -20,13 +20,12 @@ | |||||
#include "./each_mode.h" | #include "./each_mode.h" | ||||
#include <llvm/ADT/Sequence.h> | |||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | #include <mlir/Dialect/Affine/IR/AffineOps.h> | ||||
#include <mlir/Pass/Pass.h> | #include <mlir/Pass/Pass.h> | ||||
#include <mlir/Transforms/DialectConversion.h> | #include <mlir/Transforms/DialectConversion.h> | ||||
#include "mlir/IR/StandardTypes.h" | #include "mlir/IR/StandardTypes.h" | ||||
#include <llvm/ADT/Sequence.h> | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace jit; | using namespace jit; | ||||
@@ -188,6 +187,7 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { | |||||
LogicalResult matchAndRewrite(jit::ReturnOp op, | LogicalResult matchAndRewrite(jit::ReturnOp op, | ||||
PatternRewriter& rewriter) const final { | PatternRewriter& rewriter) const final { | ||||
// We lower "mgb.return" directly to "std.return". | |||||
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | ||||
return success(); | return success(); | ||||
} | } | ||||
@@ -212,6 +212,7 @@ public: | |||||
void runOnFunction() override final { | void runOnFunction() override final { | ||||
ConversionTarget target(getContext()); | ConversionTarget target(getContext()); | ||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | ||||
// target.addLegalDialect<AffineDialect>(); | |||||
target.addIllegalDialect<MgbDialect>(); | target.addIllegalDialect<MgbDialect>(); | ||||
OwningRewritePatternList patterns; | OwningRewritePatternList patterns; | ||||
@@ -236,6 +237,16 @@ std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() { | |||||
return std::make_unique<MgbToAffineLoweringPass>(); | return std::make_unique<MgbToAffineLoweringPass>(); | ||||
} | } | ||||
namespace mgb { | |||||
namespace jit { | |||||
void register_test_mgb_to_affine_lowering_pass() { | |||||
PassRegistration<MgbToAffineLoweringPass>( | |||||
"mgb-convert-to-affine", | |||||
"Perform conversion from MGB Dialect to Affine Dialect ", | |||||
[] { return std::make_unique<MgbToAffineLoweringPass>(); }); | |||||
} | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | #endif // MGB_JIT && MGB_JIT_MLIR | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -53,6 +53,16 @@ std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_llvm_pass() { | |||||
return std::make_unique<AffineToLLVMLoweringPass>(); | return std::make_unique<AffineToLLVMLoweringPass>(); | ||||
} | } | ||||
namespace mgb { | |||||
namespace jit { | |||||
void register_test_affine_to_llvm_lowering_pass() { | |||||
PassRegistration<AffineToLLVMLoweringPass>( | |||||
"mgb-codegen-convert-affine-to-llvm", | |||||
"Perform final conversion from Affine to LLVMIR ", | |||||
[] { return std::make_unique<AffineToLLVMLoweringPass>(); }); | |||||
} | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | #endif // MGB_JIT && MGB_JIT_MLIR | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -177,6 +177,12 @@ def ReturnOp : GenericOp<"return", | |||||
The operation takes an no tensor operand and produces no results. | The operation takes an no tensor operand and produces no results. | ||||
}]; | }]; | ||||
// The return operation takes an optional input operand to return. This | |||||
// value must match the return type of the enclosing function. | |||||
let arguments = (ins); | |||||
// The return operation only emits the input in the format if it is present. | |||||
let assemblyFormat = "attr-dict"; | |||||
} | } | ||||
def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> { | def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> { | ||||
@@ -19,7 +19,7 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace jit { | namespace jit { | ||||
inline const bool is_elemwise_float(const mlir::Type& dt) { | |||||
inline bool is_elemwise_float(const mlir::Type& dt) { | |||||
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | ||||
if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { | if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { | ||||
return true; | return true; | ||||
@@ -0,0 +1,27 @@ | |||||
configure_lit_site_cfg( | |||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/lit.site.cfg.py.in | |||||
${CMAKE_CURRENT_BINARY_DIR}/utils/lit.site.cfg.py | |||||
MAIN_CONFIG | |||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/lit.cfg.py | |||||
) | |||||
set(LLVM_EXTERNAL_LIT "${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/utils/lit/lit.py" CACHE STRING "External lit") | |||||
set(MLIR_MGB_TEST_DEPENDS | |||||
mgb-file-check | |||||
count not | |||||
mgb-opt | |||||
) | |||||
add_lit_testsuite(mgb-mlir-test-lit "Running the mgb regression tests" | |||||
${CMAKE_CURRENT_BINARY_DIR}/utils | |||||
DEPENDS ${MLIR_MGB_TEST_DEPENDS} | |||||
) | |||||
set_target_properties(mgb-mlir-test-lit PROPERTIES FOLDER "Tests") | |||||
add_lit_testsuites(MLIR_TEST ${CMAKE_CURRENT_SOURCE_DIR} | |||||
DEPENDS ${MLIR_MGB_TEST_DEPENDS} | |||||
) | |||||
add_custom_target(mlir_pass_check) | |||||
add_dependencies(mlir_pass_check mgb-mlir-test-lit) |
@@ -0,0 +1,16 @@ | |||||
load("//brain/megbrain/src/jit/test/mlir/utils:lit.bzl", "mlir_lit_test_suite") | |||||
filegroup( | |||||
name = "mlir_test_tools", | |||||
testonly = True, | |||||
data = [ | |||||
"//brain/megbrain/tools/mlir:mgb-opt", | |||||
"//brain/megbrain/tools/mlir:mgb-file-check" | |||||
], | |||||
) | |||||
mlir_lit_test_suite( | |||||
name = "mlir_pass_check", | |||||
data = [":mlir_test_tools"], | |||||
test_file_exts = ["mlir",] | |||||
) |
@@ -0,0 +1,58 @@ | |||||
// RUN: mgb-opt --mgb-convert-to-affine --split-input-file -canonicalize -cse %s | mgb-file-check %s | |||||
// RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s | |||||
func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () { | |||||
%0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : | |||||
(memref<2xf32>, memref<2xf32>) -> memref<2xf32> | |||||
"mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () | |||||
mgb.return | |||||
} | |||||
// CHECK-LABEL: func @add_dim1(%arg0: memref<2xf32>, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { | |||||
// CHECK: %0 = alloc() : memref<2xf32> | |||||
// CHECK: affine.for %arg3 = 0 to 2 { | |||||
// CHECK: %1 = affine.load %arg0[%arg3] : memref<2xf32> | |||||
// CHECK: %2 = affine.load %arg1[%arg3] : memref<2xf32> | |||||
// CHECK: %3 = addf %1, %2 : f32 | |||||
// CHECK: affine.store %3, %0[%arg3] : memref<2xf32> | |||||
// CHECK: } | |||||
// CHECK: affine.for %arg3 = 0 to 2 { | |||||
// CHECK: %1 = affine.load %0[%arg3] : memref<2xf32> | |||||
// CHECK: affine.store %1, %arg2[%arg3] : memref<2xf32> | |||||
// CHECK: } | |||||
// CHECK: dealloc %0 : memref<2xf32> | |||||
// CHECK: return | |||||
// CHECK: } | |||||
func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () { | |||||
%0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : | |||||
(memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32> | |||||
"mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> () | |||||
mgb.return | |||||
} | |||||
// CHECK-LABEL: func @add_dim4(%arg0: memref<4x3x64x64xf32>, %arg1: memref<4x3x64x64xf32>, %arg2: memref<4x3x64x64xf32>) { | |||||
// CHECK: %0 = alloc() : memref<4x3x64x64xf32> | |||||
// CHECK: affine.for %arg3 = 0 to 4 { | |||||
// CHECK: affine.for %arg4 = 0 to 3 { | |||||
// CHECK: affine.for %arg5 = 0 to 64 { | |||||
// CHECK: affine.for %arg6 = 0 to 64 { | |||||
// CHECK: %1 = affine.load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
// CHECK: %2 = affine.load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
// CHECK: %3 = addf %1, %2 : f32 | |||||
// CHECK: affine.store %3, %0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
// CHECK: } | |||||
// CHECK: } | |||||
// CHECK: } | |||||
// CHECK: } | |||||
// CHECK: affine.for %arg3 = 0 to 4 { | |||||
// CHECK: affine.for %arg4 = 0 to 3 { | |||||
// CHECK: affine.for %arg5 = 0 to 64 { | |||||
// CHECK: affine.for %arg6 = 0 to 64 { | |||||
// CHECK: %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
// CHECK: affine.store %1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
// CHECK: } | |||||
// CHECK: } | |||||
// CHECK: } | |||||
// CHECK: } | |||||
// CHECK: dealloc %0 : memref<4x3x64x64xf32> | |||||
// CHECK: return | |||||
// CHECK: } |
@@ -0,0 +1,5 @@ | |||||
filegroup( | |||||
name = "litfiles", | |||||
srcs = glob(["lit.bzl.*py"]), | |||||
visibility = ["//visibility:public"], | |||||
) |
@@ -0,0 +1,127 @@ | |||||
# Test definitions for Lit, the LLVM test runner. | |||||
# | |||||
"""Lit runner globbing test | |||||
""" | |||||
# Default values used by the test runner. | |||||
_default_test_file_exts = ["mlir", "pbtxt", "td"] | |||||
_default_size = "small" | |||||
_default_tags = [] | |||||
# These are patterns which we should never match, for tests, subdirectories, or | |||||
# test input data files. | |||||
_ALWAYS_EXCLUDE = [ | |||||
"**/LICENSE.txt", | |||||
"**/README.txt", | |||||
"**/lit.local.cfg", | |||||
# Exclude input files that have spaces in their names, since bazel | |||||
# cannot cope with such "targets" in the srcs list. | |||||
"**/* *", | |||||
"**/* */**", | |||||
] | |||||
def _run_lit_test(name, data, size, tags, features): | |||||
"""Runs lit on all tests it can find in `data` under megbrain/src/jit/test/mlir/ir. | |||||
Note that, due to Bazel's hermetic builds, lit only sees the tests that | |||||
are included in the `data` parameter, regardless of what other tests might | |||||
exist in the directory searched. | |||||
Args: | |||||
name: str, the name of the test, including extension. | |||||
data: [str], the data input to the test. | |||||
size: str, the size of the test. | |||||
tags: [str], tags to attach to the test. | |||||
features: [str], list of extra features to enable. | |||||
""" | |||||
native.py_test( | |||||
name = name, | |||||
srcs = ["@llvm-project//llvm:lit"], | |||||
tags = tags, | |||||
args = [ | |||||
"brain/megbrain/src/jit/test/mlir/utils --config-prefix=lit.bzl -v", | |||||
] + features, | |||||
data = data + [ | |||||
"//brain/megbrain/src/jit/test/mlir/utils:litfiles", | |||||
"//brain/megbrain/tools/mlir:mgb-file-check", | |||||
"@llvm-project//llvm:count", | |||||
"@llvm-project//llvm:not", | |||||
], | |||||
size = size, | |||||
main = "lit.py", | |||||
) | |||||
def mlir_lit_test_suite( | |||||
name, | |||||
exclude = [], | |||||
test_file_exts = _default_test_file_exts, | |||||
default_size = _default_size, | |||||
size_override = {}, | |||||
data = [], | |||||
per_test_extra_data = {}, | |||||
default_tags = _default_tags, | |||||
tags_override = {}, | |||||
features = []): | |||||
"""Creates all plausible Lit tests (and their inputs) under this directory. | |||||
Args: | |||||
name: str, name of the generated test suite. | |||||
exclude: [str], paths to exclude (for tests and inputs). | |||||
test_file_exts: [str], extensions for files that are tests. | |||||
default_size: str, the test size for targets not in "size_override". | |||||
size_override: {str: str}, sizes to use for specific tests. | |||||
data: [str], additional input data to the test. | |||||
per_test_extra_data: {str: [str]}, extra data to attach to a given file. | |||||
default_tags: [str], additional tags to attach to the test. | |||||
tags_override: {str: str}, tags to add to specific tests. | |||||
features: [str], list of extra features to enable. | |||||
""" | |||||
# Ignore some patterns by default for tests and input data. | |||||
exclude = _ALWAYS_EXCLUDE + exclude | |||||
test_names = [] | |||||
tests = native.glob( | |||||
["*." + ext for ext in test_file_exts], | |||||
exclude = exclude, | |||||
) | |||||
# Run tests individually such that errors can be attributed to a specific | |||||
# failure. | |||||
for i in range(len(tests)): | |||||
cur_test = tests[i] | |||||
# Instantiate this test with updated parameters. | |||||
internal_name = cur_test | |||||
lit_test( | |||||
name = internal_name, | |||||
data = data + per_test_extra_data.pop(cur_test, []), | |||||
size = size_override.pop(cur_test, default_size), | |||||
tags = ["windows_fail"] + default_tags + tags_override.pop(cur_test, []), | |||||
features = features, | |||||
) | |||||
test_names.append(internal_name + ".test") | |||||
native.test_suite( | |||||
name = name, | |||||
tests = test_names, | |||||
tags = default_tags, | |||||
) | |||||
def lit_test( | |||||
name, | |||||
data = [], | |||||
size = _default_size, | |||||
tags = _default_tags, | |||||
features = []): | |||||
"""Runs test files under lit. | |||||
Args: | |||||
name: str, the name of the test. | |||||
data: [str], labels that should be provided as data inputs. | |||||
size: str, the size of the test. | |||||
tags: [str], tags to attach to the test. | |||||
features: [str], list of extra features to enable. | |||||
""" | |||||
_run_lit_test(name + ".test", data + [name], size, tags, features) |
@@ -0,0 +1,52 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import os | |||||
import platform | |||||
import re | |||||
import subprocess | |||||
import tempfile | |||||
import lit.formats | |||||
import lit.util | |||||
from lit.llvm import llvm_config | |||||
from lit.llvm.subst import ToolSubst | |||||
from lit.llvm.subst import FindTool | |||||
# Configuration file for the 'lit' test runner. | |||||
# name: The name of this test suite. | |||||
config.name = 'MLIR_TEST' | |||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) | |||||
# suffixes: A list of file extensions to treat as test files. | |||||
config.suffixes = ['.mlir'] | |||||
# test_source_root: The root path where tests are located. | |||||
config.test_source_root = config.mlir_test_dir | |||||
# test_exec_root: The root path where tests should be run. | |||||
config.test_exec_root = os.environ['RUNFILES_DIR'] | |||||
llvm_config.use_default_substitutions() | |||||
# Tweak the PATH to include the tools dir. | |||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) | |||||
tool_dirs = config.mlir_mgb_tools_dirs + [config.mlir_tools_dir, config.llvm_tools_dir] | |||||
tool_names = [ | |||||
'mgb-opt', | |||||
'mlir-tblgen', | |||||
'mlir-translate', | |||||
'mgb-file-check', | |||||
] | |||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] | |||||
llvm_config.add_tool_substitutions(tools, tool_dirs) |
@@ -0,0 +1,43 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
"""Lit runner site configuration.""" | |||||
import os | |||||
import lit.llvm | |||||
config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'llvm') | |||||
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR']) | |||||
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'mlir') | |||||
config.suffixes = ['.td', '.mlir', '.pbtxt'] | |||||
mlir_mgb_tools_dirs = [ | |||||
'brain/megbrain/tools/mlir', | |||||
] | |||||
config.mlir_mgb_tools_dirs = [ | |||||
os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'], s) | |||||
for s in mlir_mgb_tools_dirs | |||||
] | |||||
test_dir = os.environ['TEST_TARGET'] | |||||
test_dir = test_dir.strip('/').rsplit(':', 1)[0] | |||||
config.mlir_test_dir = os.path.join( | |||||
os.environ['TEST_SRCDIR'], | |||||
os.environ['TEST_WORKSPACE'], | |||||
test_dir, | |||||
) | |||||
lit.llvm.initialize(lit_config, config) | |||||
# Let the main config do the real work. | |||||
lit_config.load_config( | |||||
config, | |||||
os.path.join( | |||||
os.path.join( | |||||
os.environ['TEST_SRCDIR'], | |||||
os.environ['TEST_WORKSPACE'], | |||||
'brain/megbrain/src/jit/test/mlir/utils/lit.bzl.cfg.py', | |||||
))) |
@@ -0,0 +1,49 @@ | |||||
@LIT_SITE_CFG_IN_HEADER@ | |||||
import sys | |||||
config.host_triple = "@LLVM_HOST_TRIPLE@" | |||||
config.target_triple = "@TARGET_TRIPLE@" | |||||
config.llvm_src_root = "@LLVM_SOURCE_DIR@" | |||||
config.llvm_obj_root = "@LLVM_BINARY_DIR@" | |||||
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" | |||||
config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" | |||||
config.llvm_shlib_dir = "@SHLIBDIR@" | |||||
config.llvm_shlib_ext = "@SHLIBEXT@" | |||||
config.llvm_exe_ext = "@EXEEXT@" | |||||
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" | |||||
config.python_executable = "@PYTHON_EXECUTABLE@" | |||||
config.gold_executable = "@GOLD_EXECUTABLE@" | |||||
config.ld64_executable = "@LD64_EXECUTABLE@" | |||||
config.enable_shared = @ENABLE_SHARED@ | |||||
config.enable_assertions = @ENABLE_ASSERTIONS@ | |||||
config.targets_to_build = "@TARGETS_TO_BUILD@" | |||||
config.native_target = "@LLVM_NATIVE_ARCH@" | |||||
config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') | |||||
config.host_os = "@HOST_OS@" | |||||
config.host_cc = "@HOST_CC@" | |||||
config.host_cxx = "@HOST_CXX@" | |||||
# Note: ldflags can contain double-quoted paths, so must use single quotes here. | |||||
config.host_ldflags = '@HOST_LDFLAGS@' | |||||
config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" | |||||
config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' | |||||
config.host_arch = "@HOST_ARCH@" | |||||
config.mgb_src_root = "@CMAKE_SOURCE_DIR@" | |||||
config.mgb_obj_root = "@CMAKE_BINARY_DIR@" | |||||
# Support substitution of the tools_dir with user parameters. This is | |||||
# used when we can't determine the tool dir at configuration time. | |||||
try: | |||||
config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params | |||||
config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params | |||||
except KeyError: | |||||
e = sys.exc_info()[1] | |||||
key, = e.args | |||||
lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) | |||||
import lit.llvm | |||||
lit.llvm.initialize(lit_config, config) | |||||
# Let the main config do the real work. | |||||
lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/src/jit/test/mlir/utils/lit.cfg.py") |
@@ -0,0 +1,58 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import os | |||||
import platform | |||||
import re | |||||
import subprocess | |||||
import tempfile | |||||
import lit.formats | |||||
import lit.util | |||||
from lit.llvm import llvm_config | |||||
from lit.llvm.subst import ToolSubst | |||||
from lit.llvm.subst import FindTool | |||||
# Configuration file for the 'lit' test runner. | |||||
# name: The name of this test suite. | |||||
config.name = 'MLIR_TEST' | |||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) | |||||
# suffixes: A list of file extensions to treat as test files. | |||||
config.suffixes = ['.mlir'] | |||||
# test_source_root: The root path where tests are located. | |||||
config.test_source_root = os.path.join(os.path.dirname(__file__), '../ir') | |||||
# test_exec_root: The root path where tests should be run. | |||||
config.test_exec_root = config.test_source_root | |||||
# llvm_config.use_default_substitutions() | |||||
# Tweak the PATH to include the tools dir. | |||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) | |||||
tool_dirs = [ | |||||
os.path.join(config.mgb_obj_root, 'tools/mlir'), | |||||
os.path.join(config.mgb_obj_root, 'tools/mlir/mgb-opt'), | |||||
os.path.join(config.mgb_obj_root, 'tools/mlir/mgb-file-check'), | |||||
config.llvm_tools_dir] | |||||
tool_names = [ | |||||
'mgb-opt', | |||||
'mlir-tblgen', | |||||
'mlir-translate', | |||||
'mgb-file-check', | |||||
] | |||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] | |||||
llvm_config.add_tool_substitutions(tools, tool_dirs) | |||||
lit.llvm.initialize(lit_config, config) |
@@ -0,0 +1,49 @@ | |||||
@LIT_SITE_CFG_IN_HEADER@ | |||||
import sys | |||||
config.host_triple = "@LLVM_HOST_TRIPLE@" | |||||
config.target_triple = "@TARGET_TRIPLE@" | |||||
config.llvm_src_root = "@LLVM_SOURCE_DIR@" | |||||
config.llvm_obj_root = "@LLVM_BINARY_DIR@" | |||||
config.llvm_tools_dir = "@LLVM_BINARY_DIR@/bin" | |||||
config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" | |||||
config.llvm_shlib_dir = "@SHLIBDIR@" | |||||
config.llvm_shlib_ext = "@SHLIBEXT@" | |||||
config.llvm_exe_ext = "@EXEEXT@" | |||||
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" | |||||
config.python_executable = "@PYTHON_EXECUTABLE@" | |||||
config.gold_executable = "@GOLD_EXECUTABLE@" | |||||
config.ld64_executable = "@LD64_EXECUTABLE@" | |||||
config.enable_shared = @ENABLE_SHARED@ | |||||
config.enable_assertions = @ENABLE_ASSERTIONS@ | |||||
config.targets_to_build = "@TARGETS_TO_BUILD@" | |||||
config.native_target = "@LLVM_NATIVE_ARCH@" | |||||
config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') | |||||
config.host_os = "@HOST_OS@" | |||||
config.host_cc = "@HOST_CC@" | |||||
config.host_cxx = "@HOST_CXX@" | |||||
# Note: ldflags can contain double-quoted paths, so must use single quotes here. | |||||
config.host_ldflags = '@HOST_LDFLAGS@' | |||||
config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" | |||||
config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' | |||||
config.host_arch = "@HOST_ARCH@" | |||||
config.mgb_src_root = "@CMAKE_SOURCE_DIR@" | |||||
config.mgb_obj_root = "@CMAKE_BINARY_DIR@" | |||||
# Support substitution of the tools_dir with user parameters. This is | |||||
# used when we can't determine the tool dir at configuration time. | |||||
try: | |||||
config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params | |||||
config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params | |||||
except KeyError: | |||||
e = sys.exc_info()[1] | |||||
key, = e.args | |||||
lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) | |||||
import lit.llvm | |||||
lit.llvm.initialize(lit_config, config) | |||||
# Let the main config do the real work. | |||||
lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/src/jit/test/mlir/utils/lit.cfg.py") |
@@ -43,3 +43,9 @@ endif() | |||||
if (MGE_WITH_DISTRIBUTED) | if (MGE_WITH_DISTRIBUTED) | ||||
target_link_libraries(megbrain_test megray) | target_link_libraries(megbrain_test megray) | ||||
endif() | endif() | ||||
if(MGE_WITH_JIT) | |||||
if(MGE_WITH_JIT_MLIR) | |||||
add_subdirectory(${PROJECT_SOURCE_DIR}/src/jit/test/mlir ${CMAKE_CURRENT_BINARY_DIR}/../src/jit/test/mlir) | |||||
endif() | |||||
endif() |