Browse Source

feat(mgb/jit): add llvm-lit mlir pass check

GitOrigin-RevId: acc798a49f
release-1.1
Megvii Engine Team 4 years ago
parent
commit
5cfb46e174
18 changed files with 529 additions and 8 deletions
  1. +5
    -0
      CMakeLists.txt
  2. +1
    -1
      src/jit/impl/fusion_pass.cpp
  3. +3
    -4
      src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp
  4. +13
    -2
      src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
  5. +10
    -0
      src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp
  6. +6
    -0
      src/jit/impl/mlir/ir/ops.td
  7. +1
    -1
      src/jit/impl/mlir/ir/types.h
  8. +27
    -0
      src/jit/test/mlir/CMakeLists.txt
  9. +16
    -0
      src/jit/test/mlir/ir/BUILD
  10. +58
    -0
      src/jit/test/mlir/ir/add.mlir
  11. +5
    -0
      src/jit/test/mlir/utils/BUILD
  12. +127
    -0
      src/jit/test/mlir/utils/lit.bzl
  13. +52
    -0
      src/jit/test/mlir/utils/lit.bzl.cfg.py
  14. +43
    -0
      src/jit/test/mlir/utils/lit.bzl.site.cfg.py
  15. +49
    -0
      src/jit/test/mlir/utils/lit.bzl.site.cfg.py.in
  16. +58
    -0
      src/jit/test/mlir/utils/lit.cfg.py
  17. +49
    -0
      src/jit/test/mlir/utils/lit.site.cfg.py.in
  18. +6
    -0
      test/CMakeLists.txt

+ 5
- 0
CMakeLists.txt View File

@@ -831,3 +831,8 @@ if(MSVC OR WIN32)
endif()
endforeach()
endif()

if(MGE_WITH_JIT_MLIR)
add_subdirectory(tools/mlir/mgb-opt)
add_subdirectory(tools/mlir/mgb-file-check)
endif()

+ 1
- 1
src/jit/impl/fusion_pass.cpp View File

@@ -297,7 +297,7 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
#if MGB_JIT_MLIR
//! FIXME mlir does't support broadcast currently.
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
if (!strcmp(backend, "MLIR")) {
if (backend && !strcmp(backend, "MLIR")) {
for (VarNode* var : opr->input()) {
if (!SymbolVar{var}.as_immutable_scalar().valid()) {
if (opr->node_prop().dep_map().at(var) &


+ 3
- 4
src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp View File

@@ -44,6 +44,7 @@

using namespace mlir;

namespace {
template <typename OpTy>
static void createForAllDimensions(OpBuilder& builder, Location loc,
SmallVectorImpl<Value>& values) {
@@ -80,7 +81,7 @@ static bool isSinkingBeneficiary(Operation* op) {
return isa<ConstantOp, DimOp>(op);
}

LogicalResult mlir::sinkOperationsIntoLaunchOp(gpu::LaunchOp launchOp) {
LogicalResult sink_operations_into_launch_op(gpu::LaunchOp launchOp) {
Region& launchOpBody = launchOp.body();

// Identify uses from values defined outside of the scope of the launch
@@ -232,7 +233,6 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
launchOp.erase();
}

namespace {
/// Pass that moves the kernel of each LaunchOp into its separate nested module.
///
/// This pass moves the kernel code of each LaunchOp into a function created
@@ -258,7 +258,7 @@ public:
.str();

// Pull in instructions that can be sunk
if (failed(sinkOperationsIntoLaunchOp(op)))
if (failed(sink_operations_into_launch_op(op)))
return WalkResult::interrupt();
gpu::GPUFuncOp outlinedFunc =
outlineKernelFuncImpl(op, kernelFnName, operands);
@@ -327,7 +327,6 @@ private:
return kernelModule;
}
};

} // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_gpu_kernel_outlining_pass() {


+ 13
- 2
src/jit/impl/mlir/ir/lower_to_affine_pass.cpp View File

@@ -20,13 +20,12 @@

#include "./each_mode.h"

#include <llvm/ADT/Sequence.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include "mlir/IR/StandardTypes.h"

#include <llvm/ADT/Sequence.h>

using namespace mgb;
using namespace jit;

@@ -188,6 +187,7 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {

LogicalResult matchAndRewrite(jit::ReturnOp op,
PatternRewriter& rewriter) const final {
// We lower "mgb.return" directly to "std.return".
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
return success();
}
@@ -212,6 +212,7 @@ public:
void runOnFunction() override final {
ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
// target.addLegalDialect<AffineDialect>();
target.addIllegalDialect<MgbDialect>();

OwningRewritePatternList patterns;
@@ -236,6 +237,16 @@ std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() {
return std::make_unique<MgbToAffineLoweringPass>();
}

namespace mgb {
namespace jit {
void register_test_mgb_to_affine_lowering_pass() {
PassRegistration<MgbToAffineLoweringPass>(
"mgb-convert-to-affine",
"Perform conversion from MGB Dialect to Affine Dialect ",
[] { return std::make_unique<MgbToAffineLoweringPass>(); });
}
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

+ 10
- 0
src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp View File

@@ -53,6 +53,16 @@ std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_llvm_pass() {
return std::make_unique<AffineToLLVMLoweringPass>();
}

namespace mgb {
namespace jit {
void register_test_affine_to_llvm_lowering_pass() {
PassRegistration<AffineToLLVMLoweringPass>(
"mgb-codegen-convert-affine-to-llvm",
"Perform final conversion from Affine to LLVMIR ",
[] { return std::make_unique<AffineToLLVMLoweringPass>(); });
}
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

+ 6
- 0
src/jit/impl/mlir/ir/ops.td View File

@@ -177,6 +177,12 @@ def ReturnOp : GenericOp<"return",
The operation takes an no tensor operand and produces no results.
}];

// The return operation takes an optional input operand to return. This
// value must match the return type of the enclosing function.
let arguments = (ins);

// The return operation only emits the input in the format if it is present.
let assemblyFormat = "attr-dict";
}

def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> {


+ 1
- 1
src/jit/impl/mlir/ir/types.h View File

@@ -19,7 +19,7 @@
namespace mgb {
namespace jit {

inline const bool is_elemwise_float(const mlir::Type& dt) {
inline bool is_elemwise_float(const mlir::Type& dt) {
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
return true;


+ 27
- 0
src/jit/test/mlir/CMakeLists.txt View File

@@ -0,0 +1,27 @@
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/utils/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/utils/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/utils/lit.cfg.py
)

set(LLVM_EXTERNAL_LIT "${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/utils/lit/lit.py" CACHE STRING "External lit")

set(MLIR_MGB_TEST_DEPENDS
mgb-file-check
count not
mgb-opt
)

add_lit_testsuite(mgb-mlir-test-lit "Running the mgb regression tests"
${CMAKE_CURRENT_BINARY_DIR}/utils
DEPENDS ${MLIR_MGB_TEST_DEPENDS}
)
set_target_properties(mgb-mlir-test-lit PROPERTIES FOLDER "Tests")

add_lit_testsuites(MLIR_TEST ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS ${MLIR_MGB_TEST_DEPENDS}
)

add_custom_target(mlir_pass_check)
add_dependencies(mlir_pass_check mgb-mlir-test-lit)

+ 16
- 0
src/jit/test/mlir/ir/BUILD View File

@@ -0,0 +1,16 @@
load("//brain/megbrain/src/jit/test/mlir/utils:lit.bzl", "mlir_lit_test_suite")

filegroup(
name = "mlir_test_tools",
testonly = True,
data = [
"//brain/megbrain/tools/mlir:mgb-opt",
"//brain/megbrain/tools/mlir:mgb-file-check"
],
)

mlir_lit_test_suite(
name = "mlir_pass_check",
data = [":mlir_test_tools"],
test_file_exts = ["mlir",]
)

+ 58
- 0
src/jit/test/mlir/ir/add.mlir View File

@@ -0,0 +1,58 @@
// RUN: mgb-opt --mgb-convert-to-affine --split-input-file -canonicalize -cse %s | mgb-file-check %s
// RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s

func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () {
%0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} :
(memref<2xf32>, memref<2xf32>) -> memref<2xf32>
"mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
mgb.return
}
// CHECK-LABEL: func @add_dim1(%arg0: memref<2xf32>, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
// CHECK: %0 = alloc() : memref<2xf32>
// CHECK: affine.for %arg3 = 0 to 2 {
// CHECK: %1 = affine.load %arg0[%arg3] : memref<2xf32>
// CHECK: %2 = affine.load %arg1[%arg3] : memref<2xf32>
// CHECK: %3 = addf %1, %2 : f32
// CHECK: affine.store %3, %0[%arg3] : memref<2xf32>
// CHECK: }
// CHECK: affine.for %arg3 = 0 to 2 {
// CHECK: %1 = affine.load %0[%arg3] : memref<2xf32>
// CHECK: affine.store %1, %arg2[%arg3] : memref<2xf32>
// CHECK: }
// CHECK: dealloc %0 : memref<2xf32>
// CHECK: return
// CHECK: }

func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () {
%0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} :
(memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32>
"mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> ()
mgb.return
}
// CHECK-LABEL: func @add_dim4(%arg0: memref<4x3x64x64xf32>, %arg1: memref<4x3x64x64xf32>, %arg2: memref<4x3x64x64xf32>) {
// CHECK: %0 = alloc() : memref<4x3x64x64xf32>
// CHECK: affine.for %arg3 = 0 to 4 {
// CHECK: affine.for %arg4 = 0 to 3 {
// CHECK: affine.for %arg5 = 0 to 64 {
// CHECK: affine.for %arg6 = 0 to 64 {
// CHECK: %1 = affine.load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32>
// CHECK: %2 = affine.load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32>
// CHECK: %3 = addf %1, %2 : f32
// CHECK: affine.store %3, %0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: affine.for %arg3 = 0 to 4 {
// CHECK: affine.for %arg4 = 0 to 3 {
// CHECK: affine.for %arg5 = 0 to 64 {
// CHECK: affine.for %arg6 = 0 to 64 {
// CHECK: %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32>
// CHECK: affine.store %1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: dealloc %0 : memref<4x3x64x64xf32>
// CHECK: return
// CHECK: }

+ 5
- 0
src/jit/test/mlir/utils/BUILD View File

@@ -0,0 +1,5 @@
filegroup(
name = "litfiles",
srcs = glob(["lit.bzl.*py"]),
visibility = ["//visibility:public"],
)

+ 127
- 0
src/jit/test/mlir/utils/lit.bzl View File

@@ -0,0 +1,127 @@
# Test definitions for Lit, the LLVM test runner.
#
"""Lit runner globbing test
"""

# Default values used by the test runner.
_default_test_file_exts = ["mlir", "pbtxt", "td"]
_default_size = "small"
_default_tags = []

# These are patterns which we should never match, for tests, subdirectories, or
# test input data files.
_ALWAYS_EXCLUDE = [
"**/LICENSE.txt",
"**/README.txt",
"**/lit.local.cfg",
# Exclude input files that have spaces in their names, since bazel
# cannot cope with such "targets" in the srcs list.
"**/* *",
"**/* */**",
]

def _run_lit_test(name, data, size, tags, features):
"""Runs lit on all tests it can find in `data` under megbrain/src/jit/test/mlir/ir.

Note that, due to Bazel's hermetic builds, lit only sees the tests that
are included in the `data` parameter, regardless of what other tests might
exist in the directory searched.

Args:
name: str, the name of the test, including extension.
data: [str], the data input to the test.
size: str, the size of the test.
tags: [str], tags to attach to the test.
features: [str], list of extra features to enable.
"""

native.py_test(
name = name,
srcs = ["@llvm-project//llvm:lit"],
tags = tags,
args = [
"brain/megbrain/src/jit/test/mlir/utils --config-prefix=lit.bzl -v",
] + features,
data = data + [
"//brain/megbrain/src/jit/test/mlir/utils:litfiles",
"//brain/megbrain/tools/mlir:mgb-file-check",
"@llvm-project//llvm:count",
"@llvm-project//llvm:not",
],
size = size,
main = "lit.py",
)

def mlir_lit_test_suite(
name,
exclude = [],
test_file_exts = _default_test_file_exts,
default_size = _default_size,
size_override = {},
data = [],
per_test_extra_data = {},
default_tags = _default_tags,
tags_override = {},
features = []):
"""Creates all plausible Lit tests (and their inputs) under this directory.

Args:
name: str, name of the generated test suite.
exclude: [str], paths to exclude (for tests and inputs).
test_file_exts: [str], extensions for files that are tests.
default_size: str, the test size for targets not in "size_override".
size_override: {str: str}, sizes to use for specific tests.
data: [str], additional input data to the test.
per_test_extra_data: {str: [str]}, extra data to attach to a given file.
default_tags: [str], additional tags to attach to the test.
tags_override: {str: str}, tags to add to specific tests.
features: [str], list of extra features to enable.
"""

# Ignore some patterns by default for tests and input data.
exclude = _ALWAYS_EXCLUDE + exclude

test_names = []
tests = native.glob(
["*." + ext for ext in test_file_exts],
exclude = exclude,
)

# Run tests individually such that errors can be attributed to a specific
# failure.
for i in range(len(tests)):
cur_test = tests[i]

# Instantiate this test with updated parameters.
internal_name = cur_test
lit_test(
name = internal_name,
data = data + per_test_extra_data.pop(cur_test, []),
size = size_override.pop(cur_test, default_size),
tags = ["windows_fail"] + default_tags + tags_override.pop(cur_test, []),
features = features,
)
test_names.append(internal_name + ".test")

native.test_suite(
name = name,
tests = test_names,
tags = default_tags,
)

def lit_test(
name,
data = [],
size = _default_size,
tags = _default_tags,
features = []):
"""Runs test files under lit.

Args:
name: str, the name of the test.
data: [str], labels that should be provided as data inputs.
size: str, the size of the test.
tags: [str], tags to attach to the test.
features: [str], list of extra features to enable.
"""
_run_lit_test(name + ".test", data + [name], size, tags, features)

+ 52
- 0
src/jit/test/mlir/utils/lit.bzl.cfg.py View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import os
import platform
import re
import subprocess
import tempfile

import lit.formats
import lit.util

from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool

# Configuration file for the 'lit' test runner.

# name: The name of this test suite.
config.name = 'MLIR_TEST'

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir']

# test_source_root: The root path where tests are located.
config.test_source_root = config.mlir_test_dir

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.environ['RUNFILES_DIR']

llvm_config.use_default_substitutions()

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)

tool_dirs = config.mlir_mgb_tools_dirs + [config.mlir_tools_dir, config.llvm_tools_dir]
tool_names = [
'mgb-opt',
'mlir-tblgen',
'mlir-translate',
'mgb-file-check',
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

+ 43
- 0
src/jit/test/mlir/utils/lit.bzl.site.cfg.py View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

"""Lit runner site configuration."""
import os
import lit.llvm

config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'llvm')
config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR'])
config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'mlir')
config.suffixes = ['.td', '.mlir', '.pbtxt']

mlir_mgb_tools_dirs = [
'brain/megbrain/tools/mlir',
]
config.mlir_mgb_tools_dirs = [
os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'], s)
for s in mlir_mgb_tools_dirs
]
test_dir = os.environ['TEST_TARGET']
test_dir = test_dir.strip('/').rsplit(':', 1)[0]
config.mlir_test_dir = os.path.join(
os.environ['TEST_SRCDIR'],
os.environ['TEST_WORKSPACE'],
test_dir,
)
lit.llvm.initialize(lit_config, config)

# Let the main config do the real work.
lit_config.load_config(
config,
os.path.join(
os.path.join(
os.environ['TEST_SRCDIR'],
os.environ['TEST_WORKSPACE'],
'brain/megbrain/src/jit/test/mlir/utils/lit.bzl.cfg.py',
)))

+ 49
- 0
src/jit/test/mlir/utils/lit.bzl.site.cfg.py.in View File

@@ -0,0 +1,49 @@
@LIT_SITE_CFG_IN_HEADER@

import sys

config.host_triple = "@LLVM_HOST_TRIPLE@"
config.target_triple = "@TARGET_TRIPLE@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = "@PYTHON_EXECUTABLE@"
config.gold_executable = "@GOLD_EXECUTABLE@"
config.ld64_executable = "@LD64_EXECUTABLE@"
config.enable_shared = @ENABLE_SHARED@
config.enable_assertions = @ENABLE_ASSERTIONS@
config.targets_to_build = "@TARGETS_TO_BUILD@"
config.native_target = "@LLVM_NATIVE_ARCH@"
config.llvm_bindings = "@LLVM_BINDINGS@".split(' ')
config.host_os = "@HOST_OS@"
config.host_cc = "@HOST_CC@"
config.host_cxx = "@HOST_CXX@"
# Note: ldflags can contain double-quoted paths, so must use single quotes here.
config.host_ldflags = '@HOST_LDFLAGS@'
config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@"
config.llvm_host_triple = '@LLVM_HOST_TRIPLE@'
config.host_arch = "@HOST_ARCH@"
config.mgb_src_root = "@CMAKE_SOURCE_DIR@"
config.mgb_obj_root = "@CMAKE_BINARY_DIR@"

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
try:
config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params
config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params
except KeyError:
e = sys.exc_info()[1]
key, = e.args
lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key))


import lit.llvm
lit.llvm.initialize(lit_config, config)

# Let the main config do the real work.
lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/src/jit/test/mlir/utils/lit.cfg.py")

+ 58
- 0
src/jit/test/mlir/utils/lit.cfg.py View File

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import os
import platform
import re
import subprocess
import tempfile

import lit.formats
import lit.util

from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool

# Configuration file for the 'lit' test runner.

# name: The name of this test suite.
config.name = 'MLIR_TEST'

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir']

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.join(os.path.dirname(__file__), '../ir')

# test_exec_root: The root path where tests should be run.
config.test_exec_root = config.test_source_root

# llvm_config.use_default_substitutions()

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)

tool_dirs = [
os.path.join(config.mgb_obj_root, 'tools/mlir'),
os.path.join(config.mgb_obj_root, 'tools/mlir/mgb-opt'),
os.path.join(config.mgb_obj_root, 'tools/mlir/mgb-file-check'),
config.llvm_tools_dir]
tool_names = [
'mgb-opt',
'mlir-tblgen',
'mlir-translate',
'mgb-file-check',
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

lit.llvm.initialize(lit_config, config)

+ 49
- 0
src/jit/test/mlir/utils/lit.site.cfg.py.in View File

@@ -0,0 +1,49 @@
@LIT_SITE_CFG_IN_HEADER@

import sys

config.host_triple = "@LLVM_HOST_TRIPLE@"
config.target_triple = "@TARGET_TRIPLE@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_BINARY_DIR@/bin"
config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = "@PYTHON_EXECUTABLE@"
config.gold_executable = "@GOLD_EXECUTABLE@"
config.ld64_executable = "@LD64_EXECUTABLE@"
config.enable_shared = @ENABLE_SHARED@
config.enable_assertions = @ENABLE_ASSERTIONS@
config.targets_to_build = "@TARGETS_TO_BUILD@"
config.native_target = "@LLVM_NATIVE_ARCH@"
config.llvm_bindings = "@LLVM_BINDINGS@".split(' ')
config.host_os = "@HOST_OS@"
config.host_cc = "@HOST_CC@"
config.host_cxx = "@HOST_CXX@"
# Note: ldflags can contain double-quoted paths, so must use single quotes here.
config.host_ldflags = '@HOST_LDFLAGS@'
config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@"
config.llvm_host_triple = '@LLVM_HOST_TRIPLE@'
config.host_arch = "@HOST_ARCH@"
config.mgb_src_root = "@CMAKE_SOURCE_DIR@"
config.mgb_obj_root = "@CMAKE_BINARY_DIR@"

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
try:
config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params
config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params
except KeyError:
e = sys.exc_info()[1]
key, = e.args
lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key))


import lit.llvm
lit.llvm.initialize(lit_config, config)

# Let the main config do the real work.
lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/src/jit/test/mlir/utils/lit.cfg.py")

+ 6
- 0
test/CMakeLists.txt View File

@@ -43,3 +43,9 @@ endif()
if (MGE_WITH_DISTRIBUTED)
target_link_libraries(megbrain_test megray)
endif()

if(MGE_WITH_JIT)
if(MGE_WITH_JIT_MLIR)
add_subdirectory(${PROJECT_SOURCE_DIR}/src/jit/test/mlir ${CMAKE_CURRENT_BINARY_DIR}/../src/jit/test/mlir)
endif()
endif()

Loading…
Cancel
Save