
feat(mgb/jit): add mlir backend for cpu and cuda

GitOrigin-RevId: 814fed047e
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
commit a51d5b4c31
29 changed files with 1999 additions and 4 deletions
  1. .gitattributes (+1 -0)
  2. src/core/impl/comp_node_env.cpp (+9 -0)
  3. src/core/include/megbrain/comp_node_env.h (+22 -0)
  4. src/jit/impl/compiler.cpp (+22 -1)
  5. src/jit/impl/mlir/compiler.cpp (+198 -0)
  6. src/jit/impl/mlir/compiler.h (+56 -0)
  7. src/jit/impl/mlir/executable_cpu.cpp (+134 -0)
  8. src/jit/impl/mlir/executable_cpu.h (+51 -0)
  9. src/jit/impl/mlir/executable_cuda.cpp (+166 -0)
  10. src/jit/impl/mlir/executable_cuda.h (+74 -0)
  11. src/jit/impl/mlir/ir/common.cpp (+41 -0)
  12. src/jit/impl/mlir/ir/common.h (+32 -0)
  13. src/jit/impl/mlir/ir/dialect.cpp (+91 -0)
  14. src/jit/impl/mlir/ir/lower_to_affine_pass.cpp (+159 -0)
  15. src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp (+211 -0)
  16. src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp (+56 -0)
  17. src/jit/impl/mlir/ir/ops.td (+72 -0)
  18. src/jit/impl/mlir/ir/shape_inference_interface.td (+30 -0)
  19. src/jit/impl/mlir/ir/shape_inference_pass.cpp (+100 -0)
  20. src/jit/impl/mlir/mlir_gen.cpp (+207 -0)
  21. src/jit/impl/mlir/mlir_gen.h (+42 -0)
  22. src/jit/impl/mlir/utils.h (+45 -0)
  23. src/jit/include/megbrain/jit/mlir/ir/dialect.h (+45 -0)
  24. src/jit/include/megbrain/jit/mlir/ir/passes.h (+43 -0)
  25. src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h (+33 -0)
  26. src/jit/include/megbrain/jit/param_elem_visitor.h (+2 -2)
  27. src/jit/test/codegen.cpp (+53 -0)
  28. src/jit/test/helper.cpp (+3 -0)
  29. src/jit/test/helper.h (+1 -1)

.gitattributes (+1 -0)

@@ -4,3 +4,4 @@ dnn/src/cuda/conv_bias/int8/kimpl/* binary
dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text

src/core/impl/comp_node_env.cpp (+9 -0)

@@ -162,6 +162,15 @@ void mgb::_on_cuda_error(const char* expr, cudaError_t err, const char* file,
cudaGetErrorString(err), expr, file, func, line);
}

void mgb::_on_cuda_cu_error(const char* expr, CUresult err, const char* file,
const char* func, int line) {
const char* msg;
cuGetErrorName(err, &msg);
mgb_throw(CudaError, "cuda error %d: %s (%s at %s:%s:%d)", int(err), msg,
expr, file, func, line);
}


void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node,
const ContinuationCtx<cudaStream_t>& cont) {
m_comp_node = comp_node;


src/core/include/megbrain/comp_node_env.h (+22 -0)

@@ -22,6 +22,7 @@

#if MGB_CUDA
#include <cuda_runtime.h>
#include <cuda.h>

#if MGB_ENABLE_LOGGING
#define MGB_CUDA_CHECK(expr) \
@@ -32,6 +33,16 @@
__func__, __LINE__); \
} \
} while (0)

#define MGB_CUDA_CU_CHECK(expr) \
do { \
CUresult __cuda_check_code = (expr); \
if (!mgb_likely(__cuda_check_code == CUDA_SUCCESS)) { \
::mgb::_on_cuda_cu_error(#expr, __cuda_check_code, __FILE__, \
__func__, __LINE__); \
} \
} while (0)

#else
#define MGB_CUDA_CHECK(expr) \
do { \
@@ -41,6 +52,14 @@
} \
} while (0)

#define MGB_CUDA_CU_CHECK(expr) \
do { \
CUresult __cuda_check_code = (expr); \
if (!mgb_likely(__cuda_check_code == CUDA_SUCCESS)) { \
::mgb::_on_cuda_cu_error(#expr, __cuda_check_code, "", "", 1); \
} \
} while (0)

#endif //MGB_ENABLE_LOGGING
#endif //MGB_CUDA

@@ -178,6 +197,9 @@ namespace mgb {
#if MGB_CUDA
[[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err,
const char* file, const char* func, int line);
[[noreturn]] void _on_cuda_cu_error(const char* expr, CUresult err,
const char* file, const char* func,
int line);
#endif
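
A minimal usage sketch of the new driver-API check macro; `module` and `image` are hypothetical here. Any CUresult other than CUDA_SUCCESS is routed to _on_cuda_cu_error, which throws CudaError with the name obtained from cuGetErrorName:

    // `image` is a hypothetical pointer to a cubin or PTX blob.
    CUmodule module;
    MGB_CUDA_CU_CHECK(cuModuleLoadData(&module, image));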




src/jit/impl/compiler.cpp (+22 -1)

@@ -11,6 +11,7 @@

#include "./halide/compiler_cuda.h"
#include "./nvrtc/compiler_cuda.h"
#include "./mlir/compiler.h"

#include "megbrain/jit/compiler.h"
#include "megbrain/utils/hash.h"
@@ -54,6 +55,8 @@ bool Compiler::is_supported_device(CompNode::DeviceType device) {
case CompNode::DeviceType::CUDA:
return true;
#endif
case CompNode::DeviceType::CPU:
return true;
default:
return false;
}
@@ -87,12 +90,30 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
break;
}
#endif
#if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) {
compiler = std::make_unique<MLIRCompiler>(
CompNode::DeviceType::CUDA);
break;
}
#endif
if (!backend || !strcmp(backend, "NVRTC")) {
compiler = std::make_unique<CudaCompiler>();
break;
}
#endif
// fall through
mgb_throw(InternalError, "No compiler support for cuda");
break;
case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) {
compiler = std::make_unique<MLIRCompiler>(
CompNode::DeviceType::CPU);
break;
}
#endif
mgb_throw(InternalError, "No compiler support for cpu");
break;
default:
mgb_throw(InternalError,
"unsupported JIT config: "

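The `backend` string consulted in the dispatch above comes from the MGB_JIT_BACKEND environment variable (the test helper below sets it the same way); a minimal sketch for forcing the new backend:

    // Must be set before Compiler::get() constructs the compiler.
    setenv("MGB_JIT_BACKEND", "MLIR", 1);
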

src/jit/impl/mlir/compiler.cpp (+198 -0)

@@ -0,0 +1,198 @@
/**
* \file src/jit/impl/mlir/compiler.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./compiler.h"
#include "./executable_cpu.h"
#include "./executable_cuda.h"
#include "./mlir_gen.h"

#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include <mlir/Conversion/GPUCommon/GPUCommonPass.h>
#include <mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/GPU/Passes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Module.h>
#include <mlir/InitAllDialects.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Target/NVVMIR.h>
#include <mlir/Transforms/Passes.h>

#include <llvm/Support/TargetSelect.h>

using namespace mgb;
using namespace jit;

namespace {

struct LLVMInitializer {
LLVMInitializer() {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
}
};
static LLVMInitializer initializer;

#if MGB_CUDA
mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location,
llvm::StringRef) {
// Pass the PTX through unchanged; cuModuleLoadData() in MLIRCUDAExecutable
// lets the CUDA driver JIT-compile it when the module is loaded.
mlir::OwnedBlob result = std::make_unique<std::vector<char>>(
ptx.data(), ptx.data() + ptx.size());

return result;
}

#endif

void add_cpu_lowering_pass(mlir::PassManager& manager) {
{
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
opt_pm.addPass(create_shape_inference_pass());
opt_pm.addPass(mlir::createCanonicalizerPass());
opt_pm.addPass(mlir::createCSEPass());
}

manager.addPass(create_lower_to_affine_pass());
{
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
opt_pm.addPass(mlir::createCanonicalizerPass());
opt_pm.addPass(mlir::createCSEPass());
opt_pm.addPass(mlir::createLoopFusionPass());
opt_pm.addPass(mlir::createMemRefDataFlowOptPass());
}
manager.addPass(create_lower_to_llvm_pass());
}

#if MGB_CUDA
void add_cuda_lowering_pass(mlir::PassManager& manager, CompNode cn) {
{
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
opt_pm.addPass(create_shape_inference_pass());
opt_pm.addPass(mlir::createCanonicalizerPass());
opt_pm.addPass(mlir::createCSEPass());
}
manager.addPass(create_lower_to_gpu_pass());
{
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
opt_pm.addPass(mlir::createCanonicalizerPass());
opt_pm.addPass(mlir::createCSEPass());
opt_pm.addPass(mlir::createLoopFusionPass());
opt_pm.addPass(mlir::createMemRefDataFlowOptPass());
}
manager.addPass(mlir::createGpuKernelOutliningPass());
{
auto& kernel_pm = manager.nest<gpu::GPUModuleOp>();
kernel_pm.addPass(mlir::createLowerGpuOpsToNVVMOpsPass());

auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
kernel_pm.addPass(mlir::createConvertGPUKernelToBlobPass(
mlir::translateModuleToNVVMIR, compile_ptx_to_cubin,
"nvptx64-nvidia-cuda",
ssprintf("sm_%d%d", prop.major, prop.minor), "+ptx60",
MLIRCUDAExecutable::sm_blob_annotation));
}
}
#endif

} // namespace

/* ==================== MLIRCompiler ===================== */

thread_local mlir::MLIRContext MLIRCompiler::sm_ctx;

MLIRCompiler::MLIRCompiler(CompNode::DeviceType device_type)
: m_device_type{device_type} {
mlir::registerAllDialects();
mlir::registerDialect<MgbDialect>();

#if MGB_CUDA
if (m_device_type == CompNode::DeviceType::CUDA) {
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
}
#endif
}

void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module,
CompNode cn) {
mgb_assert(cn.device_type() == m_device_type);
mlir::PassManager manager(module->getContext());
switch (m_device_type) {
case CompNode::DeviceType::CPU:
add_cpu_lowering_pass(manager);
break;
#if MGB_CUDA
case CompNode::DeviceType::CUDA:
add_cuda_lowering_pass(manager, cn);
break;
#endif
default:
mgb_throw(InternalError, "Unsupport device type: %d",
static_cast<int>(m_device_type));
break;
}
mgb_assert(mlir::succeeded(manager.run(*module)));
}

std::unique_ptr<Executable> MLIRCompiler::do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) {
MGB_MARK_USED_VAR(graph);
MGB_MARK_USED_VAR(args);

mlir::MLIRContext ctx;
ctx.printStackTraceOnDiagnostic(true);
ctx.printOpOnDiagnostic(true);

auto&& res = mlir_gen(ctx, graph, args);
mgb_assert(res.second, "failed to generate module");

CompNode cn = args.owner->comp_node();
run_lowering_pass(res.second, cn);
switch (cn.device_type()) {
case CompNode::DeviceType::CPU:
return std::make_unique<MLIRCPUExecutable>(res.second,
res.first.str());
#if MGB_CUDA
case CompNode::DeviceType::CUDA:
return std::make_unique<MLIRCUDAExecutable>(res.second,
res.first.str());
#endif
default:
mgb_throw(InternalError, "Unsupport device type: %d",
static_cast<int>(cn.device_type()));
return nullptr;
}
}

size_t MLIRCompiler::get_nr_workspace_outputs(JITExecutor* opr) const {
MGB_MARK_USED_VAR(opr);
return 0;
}

void MLIRCompiler::init_workspace_size_infer(JITExecutor* opr) {
MGB_MARK_USED_VAR(opr);
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
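
A hypothetical debugging aid for this file's pipeline: since mlir::OwningModuleRef dereferences to a ModuleOp, the module can be dumped to stderr before and after lowering to compare the mgb-dialect IR with the lowered form:

    auto&& res = mlir_gen(ctx, graph, args);
    res.second->dump();                // mgb-dialect IR straight out of mlir_gen
    run_lowering_pass(res.second, cn);
    res.second->dump();                // LLVM IR (CPU) or outlined GPU module (CUDA)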

src/jit/impl/mlir/compiler.h (+56 -0)

@@ -0,0 +1,56 @@
/**
* \file src/jit/impl/mlir/compiler.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/jit/compiler.h"

#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/IR/Module.h>

namespace mgb {
namespace jit {

/*!
* \brief MLIR compiler
*/
class MLIRCompiler final : public Compiler {
std::unique_ptr<Executable> do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) override;

public:
MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU);
Property property() const override {
using F = Property::Flag;
return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
JITFeatureBits::NONE, 64};
}

size_t get_nr_workspace_outputs(JITExecutor* opr) const override;

void init_workspace_size_infer(JITExecutor* opr) override;

private:
void run_lowering_pass(mlir::OwningModuleRef& module, CompNode cn);

CompNode::DeviceType m_device_type;
static thread_local mlir::MLIRContext sm_ctx;
};

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/impl/mlir/executable_cpu.cpp (+134 -0)

@@ -0,0 +1,134 @@
/**
* \file src/jit/impl/mlir/executable_cpu.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./executable_cpu.h"
#include "./utils.h"

#include <mlir/ExecutionEngine/OptUtils.h>

using namespace mgb;
using namespace jit;

namespace {

template <int N>
void* tensor2memref_dim(const megdnn::TensorND& tensor) {
switch (tensor.layout.dtype.enumv()) {
case megdnn::DTypeEnum::Float32: {
StridedMemRefType<float, N>* desc =
static_cast<StridedMemRefType<float, N>*>(
malloc(sizeof(StridedMemRefType<float, N>)));
desc->basePtr = tensor.ptr<float>();
desc->data = tensor.ptr<float>();
desc->offset = 0;
for (size_t i = 0; i < tensor.layout.ndim; i++) {
desc->sizes[i] = tensor.layout.shape[i];
desc->strides[i] = tensor.layout.stride[i];
}
return desc;
break;
}
default:
mgb_throw(InternalError, "Unsupport dtype, got %s",
tensor.layout.dtype.name());
break;
}
return nullptr;
}

void* tensor2memref(const megdnn::TensorND& tensor) {
switch (tensor.layout.ndim) {
#define cb(i) \
case i: \
return tensor2memref_dim<i>(tensor)

cb(1);
cb(2);
cb(3);
cb(4);
cb(5);
default:
mgb_throw(InternalError, "Unsupported ndim, got %zu",
tensor.layout.ndim);
#undef cb
}
}

} // namespace
MLIRCPUExecutable::MLIRCPUExecutable(mlir::OwningModuleRef& module,
const std::string& kernel_name)
: m_kernel_name{kernel_name} {
auto opt_pipeline = mlir::makeOptimizingTransformer(3, 3, 0);
std::vector<std::string> libs;
auto&& engine = mlir::ExecutionEngine::create(
*module, opt_pipeline, llvm::None,
std::vector<llvm::StringRef>(libs.begin(), libs.end()), true,
false);
mgb_assert(engine);
m_engine = std::move(*engine);
}

void MLIRCPUExecutable::execute(JITExecutor* fusion_opr) {
auto&& args = fusion_opr->args();
std::vector<void*> args_array(args.inputs.size() + args.outputs.size());
std::vector<void*> args_array_pointer(args.inputs.size() +
args.outputs.size());
size_t idx = 0;
for (size_t i = 0; i < args.inputs.size(); i++) {
args_array[idx] =
tensor2memref({args.inputs[i].from->dev_tensor().raw_ptr(),
args.inputs[i].layout});
args_array_pointer[idx] = &args_array[idx];
idx++;
}
int64_t nr_elements = 0;
for (size_t i = 0; i < args.outputs.size(); i++) {
if (nr_elements == 0) {
nr_elements = args.outputs[i].layout.total_nr_elems();
} else {
mgb_assert(static_cast<size_t>(nr_elements) ==
args.outputs[i].layout.total_nr_elems(),
"The number of elements of outputs mismatch, expected: "
"%zu got: %zu(%s)",
static_cast<size_t>(nr_elements),
args.outputs[i].layout.total_nr_elems(),
args.outputs[i].layout.to_string().c_str());
}
args_array[idx] =
tensor2memref({args.outputs[i].from->dev_tensor().raw_ptr(),
args.outputs[i].layout});
args_array_pointer[idx] = &args_array[idx];
idx++;
}

args_array_pointer[idx++] = &nr_elements;
std::string adapter_name = std::string("_mlir_ciface_") + m_kernel_name;
auto err = m_engine->invoke(
adapter_name, llvm::MutableArrayRef<void*>(args_array_pointer));
if (err) {
mgb_throw(InternalError, "failed to run MLIR kernel %s\n",
m_kernel_name.c_str());
}

for (size_t i = 0; i < args_array.size(); i++) {
free(args_array[i]);
}
}

MLIRCPUExecutable::~MLIRCPUExecutable() {}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
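
For orientation, the StridedMemRefType<T, N> descriptors malloc'ed above follow MLIR's memref calling convention from CRunnerUtils.h; for rank 2 and f32 the payload is equivalent to this sketch (illustration only, not a redefinition):

    struct MemRefDescF32Rank2 {
        float* basePtr;      // pointer returned by the allocator
        float* data;         // aligned pointer the kernel actually reads
        int64_t offset;      // element offset applied to data
        int64_t sizes[2];    // shape
        int64_t strides[2];  // strides, in elements
    };

The _mlir_ciface_ prefix selects the C-interface wrapper emitted because mlir_gen.cpp sets llvm.emit_c_interface; that wrapper takes one pointer to such a descriptor per memref argument, which is why args_array_pointer stores the address of each descriptor pointer.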

src/jit/impl/mlir/executable_cpu.h (+51 -0)

@@ -0,0 +1,51 @@
/**
* \file src/jit/impl/mlir/executable_cpu.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/jit/compiler.h"

#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/IR/Module.h>

namespace mgb {
namespace jit {

/*!
* \brief Executable class for MLIR
*/
class MLIRCPUExecutable final : public Executable {
public:
MLIRCPUExecutable(mlir::OwningModuleRef& module,
const std::string& kernel_name);
~MLIRCPUExecutable();

/*!
* \brief execute
* An executable instance can be executed by one or more fusion oprs
*/
void execute(JITExecutor* fusion_opr) override final;

private:
std::unique_ptr<mlir::ExecutionEngine> m_engine;
std::string m_kernel_name;
};

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/impl/mlir/executable_cuda.cpp (+166 -0)

@@ -0,0 +1,166 @@
/**
* \file src/jit/impl/mlir/executable_cuda.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include <vector>
#include "megbrain_build_config.h"
#include "megdnn/dtype.h"
#if MGB_JIT && MGB_JIT_MLIR

#if MGB_CUDA
#include "./executable_cuda.h"
#include "./utils.h"
#include "megbrain/utils/timer.h"
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/comp_node_env.h"

#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/IR/OpDefinition.h>

using namespace mgb;
using namespace jit;

namespace {
template <int out_dim, typename ctype>
void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
int block_size) {
auto&& args = fusion_opr->args();
std::vector<StridedMemRefType<ctype, out_dim>> param_holders;
std::vector<void*> params;

auto set_params = [&param_holders, &params](
void* ptr, const megdnn::TensorLayout& layout) {
param_holders.push_back(StridedMemRefType<ctype, out_dim>{});
StridedMemRefType<ctype, out_dim>& desc = param_holders.back();
desc.basePtr = static_cast<ctype*>(ptr);
params.push_back(&(desc.basePtr));
desc.data = static_cast<ctype*>(ptr);
params.push_back(&(desc.data));
desc.offset = 0;
params.push_back(&(desc.offset));
for (size_t i = 0; i < layout.ndim; i++) {
desc.sizes[i] = layout.shape[i];
params.push_back(&(desc.sizes[i]));
desc.strides[i] = layout.stride[i];
params.push_back(&(desc.strides[i]));
}
};
for (const auto& arg : args.inputs) {
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
}
int64_t nr_elements = 0;
for (const auto& arg : args.outputs) {
if (nr_elements == 0) {
nr_elements = arg.layout.total_nr_elems();
} else {
mgb_assert(static_cast<size_t>(nr_elements) ==
arg.layout.total_nr_elems(),
"The number of elements of outputs mismatch, expected: "
"%zu got: %zu(%s)",
static_cast<size_t>(nr_elements),
arg.layout.total_nr_elems(),
arg.layout.to_string().c_str());
}

set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
}
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(fusion_opr->comp_node());

int64_t num_block = (nr_elements - 1) / block_size + 1;
params.insert(params.begin(), &nr_elements);
MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0,
env.cuda_env().stream, params.data(), 0));
}
} // namespace

const std::string MLIRCUDAExecutable::sm_blob_annotation = "nvvm.cubin";
MLIRCUDAExecutable::MLIRCUDAExecutable(mlir::OwningModuleRef& module,
const std::string& kernel_name) {
m_kernel_name = kernel_name + "_kernel";
auto kernel_module =
module->lookupSymbol<mlir::gpu::GPUModuleOp>(m_kernel_name);
mgb_assert(kernel_module, "Expected gpu kernel module");

auto binary_attr = kernel_module.getAttrOfType<mlir::StringAttr>(
llvm::StringRef(sm_blob_annotation));
mgb_assert(binary_attr, "Missing %s attribute in gpu kernel module",
sm_blob_annotation.c_str());
m_kernel_data = binary_attr.getValue().str();
}

void MLIRCUDAExecutable::execute(JITExecutor* fusion_opr) {
FuncCache* func;
auto cn = fusion_opr->comp_node();
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
func = &m_func_cache[{prop.major, prop.minor}];
func->kernel_data = m_kernel_data;
func->exec(fusion_opr, this);
}

MLIRCUDAExecutable::~MLIRCUDAExecutable() {}

void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
const MLIRCUDAExecutable* cuda_exe) {
Func* func;
{
MGB_LOCK_GUARD(mtx);
auto ins = cn2func.insert({fusion_opr->comp_node(), {}});
func = &ins.first->second;
if (ins.second) {
MGB_CUDA_CU_CHECK(
cuModuleLoadData(&func->module, kernel_data.data()));
MGB_CUDA_CU_CHECK(
cuModuleGetFunction(&func->func, func->module,
cuda_exe->m_kernel_name.c_str()));
int min_grid_size = 0;
MGB_CUDA_CU_CHECK(cuOccupancyMaxPotentialBlockSize(
&min_grid_size, &func->block_size, func->func, nullptr, 0,
0));
}
}

mgb_assert(fusion_opr->args().outputs.size() == 1,
"Currently only support 1 outputs, got %zu",
fusion_opr->args().outputs.size());
int out_dim = fusion_opr->args().outputs[0].layout.ndim;
DType dtype = fusion_opr->args().outputs[0].layout.dtype;
#define cb_outdim(_ndim, _dtype) \
if (_ndim == out_dim) { \
setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \
func->block_size); \
return; \
}

#define cb(_dtype) \
cb_outdim(1, float); \
cb_outdim(2, float); \
cb_outdim(3, float); \
cb_outdim(4, float); \
mgb_throw(InternalError, "unsupported out_dim=%zu", \
static_cast<size_t>(out_dim)); \
return;

switch (dtype.enumv()) {
case DTypeEnum::Float32:
cb(float);
default:
mgb_throw(InternalError, "unsupport dtype: %s", dtype.name());
}
#undef cb
#undef cb_outdim
}

#endif // MGB_CUDA
#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
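
The launch configuration is a plain ceiling division over a 1-D grid, one thread per output element; with hypothetical numbers (the real block_size comes from cuOccupancyMaxPotentialBlockSize above):

    int64_t block_size = 256;                                // hypothetical
    int64_t nr_elements = 23 * 42;                           // 966 elements
    int64_t num_block = (nr_elements - 1) / block_size + 1;  // 4 blocks

Note that, unlike the CPU path, set_params pushes every scalar field of each descriptor as a separate kernel parameter, matching the argument-expanded ABI of the outlined (non-C-interface) GPU kernel.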

src/jit/impl/mlir/executable_cuda.h (+74 -0)

@@ -0,0 +1,74 @@
/**
* \file src/jit/impl/mlir/executable_cuda.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#if MGB_CUDA
#include "megbrain/jit/compiler.h"

#include <mlir/IR/Module.h>

#include <cuda.h>

namespace mgb {
namespace jit {

/*!
* \brief Executable class for MLIR
*/
class MLIRCUDAExecutable final : public Executable {
public:
MLIRCUDAExecutable(mlir::OwningModuleRef& module,
const std::string& kernel_name);
~MLIRCUDAExecutable();

/*!
* \brief execute
* An executable instance can be executed by one or more fusion oprs
*/
void execute(JITExecutor* fusion_opr) override final;

const static std::string sm_blob_annotation;
private:
//! cache for a func on a specific device
struct FuncCache {
struct Func {
int block_size{-1};
CUmodule module{nullptr};
CUfunction func{nullptr};
};

std::mutex mtx;
std::string kernel_data;
CompNode::UnorderedMap<Func> cn2func;

void exec(const JITExecutor* fusion_opr,
const MLIRCUDAExecutable* cuda_exe);
};

std::string m_kernel_name;
std::string m_kernel_data;

//! (cuda_major, cuda_minor) => func
ThinHashMap<std::pair<uint32_t, uint32_t>, FuncCache> m_func_cache;
};

} // namespace jit
} // namespace mgb

#endif // MGB_CUDA
#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/impl/mlir/ir/common.cpp (+41 -0)

@@ -0,0 +1,41 @@
/**
* \file src/jit/impl/mlir/ir/common.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "common.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>

using namespace mgb;
using namespace jit;

mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type,
mlir::Location loc,
mlir::PatternRewriter& rewriter) {
auto alloc = rewriter.create<mlir::AllocOp>(loc, type);

// Make sure to allocate at the beginning of the block.
auto* parent_block = alloc.getOperation()->getBlock();
alloc.getOperation()->moveBefore(&parent_block->front());

// Make sure to deallocate this alloc at the end of the block. This is fine
// as the generated kernel functions have no control flow.
auto dealloc = rewriter.create<mlir::DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parent_block->back());
return alloc;
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/impl/mlir/ir/common.h (+32 -0)

@@ -0,0 +1,32 @@
/**
* \file src/jit/impl/mlir/ir/common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/Value.h>

namespace mgb {
namespace jit {

mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc,
mlir::PatternRewriter& rewriter);

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/impl/mlir/ir/dialect.cpp (+91 -0)

@@ -0,0 +1,91 @@
/**
* \file src/jit/impl/mlir/ir/dialect.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/jit/mlir/ir/dialect.h"

#include <mlir/Support/LogicalResult.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>

using namespace mgb;
using namespace jit;

MgbDialect::MgbDialect(mlir::MLIRContext *ctx) : mlir::Dialect("mgb", ctx) {
addOperations<
#define GET_OP_LIST
#include "megbrain/jit/mlir/ir/ops.cpp.inc"
>();
}

static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
llvm::SMLoc operandsLoc = parser.getCurrentLocation();
Type type;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return mlir::failure();

// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
result.addTypes(funcType.getResults());
return mlir::success();
}

// Otherwise, the parsed type is the type of both operands and results.
if (parser.resolveOperands(operands, type, result.operands))
return mlir::failure();
result.addTypes(type);
return mlir::success();
}

static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
printer << op->getName() << " " << op->getOperands();
printer.printOptionalAttrDict(op->getAttrs());
printer << " : ";

// If all of the types are the same, print the type directly.
Type resultType = *op->result_type_begin();
if (llvm::all_of(op->getOperandTypes(),
[=](Type type) { return type == resultType; })) {
printer << resultType;
return;
}

// Otherwise, print a functional type.
printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
}

///////////////////////// ElemwiseOp /////////////////////////////////////////////

void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
mlir::Value lhs, mlir::Value rhs) {
state.addTypes(lhs.getType());
state.addOperands({lhs, rhs});
}
void AddOp::infer_shapes() { getResult().setType(getOperand(0).getType()); }

#define GET_OP_CLASSES
#include "megbrain/jit/mlir/ir/ops.cpp.inc"

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen
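
The builder declared in ops.td and defined above lets client code create an mgb.add without spelling out the result type; a minimal sketch, assuming `builder`, `loc`, `lhs`, and `rhs` are in scope (mlir_gen.cpp below uses this exact pattern):

    // AddOp::build infers the result type from lhs.
    mlir::Value sum = builder.create<jit::AddOp>(loc, lhs, rhs);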

src/jit/impl/mlir/ir/lower_to_affine_pass.cpp (+159 -0)

@@ -0,0 +1,159 @@
/**
* \file src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include "./common.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/Sequence.h>

using namespace mgb;
using namespace jit;

namespace {

using LoopIterationFn = function_ref<Value(
OpBuilder& rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;

void lower_op_to_loops(Operation* op, ValueRange operands,
PatternRewriter& rewriter,
LoopIterationFn process_iteration) {
auto memref_type = (*op->result_type_begin()).cast<MemRefType>();
auto loc = op->getLoc();

auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);

SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
buildAffineLoopNest(
rewriter, loc, lower_bounds, memref_type.getShape(), steps,
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
Value value_to_store =
process_iteration(nested_builder, operands, ivs);
nested_builder.create<AffineStoreOp>(loc, value_to_store, alloc,
ivs);
});

// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
}

template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext* ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
lower_op_to_loops(
op, operands, rewriter,
[loc](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
typename BinaryOp::Adaptor binary_adaptor(memref_operands);

auto loaded_lhs = builder.create<AffineLoadOp>(
loc, binary_adaptor.lhs(), loop_ivs);
auto loaded_rhs = builder.create<AffineLoadOp>(
loc, binary_adaptor.rhs(), loop_ivs);

return builder.create<LoweredBinaryOp>(loc, loaded_lhs,
loaded_rhs);
});
return success();
}
};
using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;

struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx)
: ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
auto memref_type = operands[0].getType().cast<MemRefType>();
AssignOpAdaptor assign_adaptor(operands);

SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
buildAffineLoopNest(
rewriter, loc, lower_bounds, memref_type.getShape(), steps,
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
auto loaded_lhs = nested_builder.create<AffineLoadOp>(
loc, assign_adaptor.lhs(), ivs);
nested_builder.create<AffineStoreOp>(
loc, loaded_lhs, assign_adaptor.rhs(), ivs);
});

rewriter.eraseOp(op);
return success();
}
};

struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
using OpRewritePattern<jit::ReturnOp>::OpRewritePattern;

LogicalResult matchAndRewrite(jit::ReturnOp op,
PatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
return success();
}
};

class MgbToAffineLoweringPass
: public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
void runOnFunction() override final {
auto function = getFunction();

// Verify that the given function returns no results.
if (function.getType().getNumResults()) {
mgb_log_error("expected 'main' to have 0 results");
return signalPassFailure();
}

ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
target.addIllegalDialect<MgbDialect>();

OwningRewritePatternList patterns;
patterns.insert<AddOpLowering, ReturnOpLowering, AssignOpLowering>(
&getContext());

if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() {
return std::make_unique<MgbToAffineLoweringPass>();
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen
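
Conceptually, the loop nest that lower_op_to_loops builds for a rank-2 mgb.add behaves like this plain C++ (a sketch for intuition; the generated IR uses affine.load/affine.store on memrefs):

    void add2d(const float* lhs, const float* rhs, float* out,
               int64_t rows, int64_t cols) {
        for (int64_t i = 0; i < rows; ++i)      // outer affine.for
            for (int64_t j = 0; j < cols; ++j)  // inner affine.for
                out[i * cols + j] =
                        lhs[i * cols + j] + rhs[i * cols + j];
    }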

src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp (+211 -0)

@@ -0,0 +1,211 @@
/**
* \file src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include "../utils.h"

#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>

using namespace mgb;
using namespace jit;

namespace {

mlir::Value get_operand(ConversionPatternRewriter& rewriter,
const mlir::Location& loc, const mlir::Value& val,
const mlir::Value& index) {
if (val.getType().isa<mlir::MemRefType>()) {
return rewriter.create<LoadOp>(loc, val, index);
} else {
return val;
}
}

mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
auto block_idx = rewriter.create<gpu::BlockIdOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
auto group_size = rewriter.create<gpu::BlockDimOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
Value index = rewriter.create<AddIOp>(
loc, thread_idx,
rewriter.create<MulIOp>(loc, block_idx, group_size));

return index;
}

template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}

LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();

typename BinaryOp::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));

auto index = get_tid(rewriter, loc);
auto loaded_lhs =
get_operand(rewriter, loc, binary_adaptor.lhs(), index);
auto loaded_rhs =
get_operand(rewriter, loc, binary_adaptor.rhs(), index);

auto binary_op =
rewriter.create<LoweredBinaryOp>(loc, loaded_lhs, loaded_rhs);

rewriter.replaceOp(op, binary_op.getResult());
return success();
}

private:
gpu::LaunchOp* m_launch_op;
};

using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;

struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
: ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}

LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value>,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
auto loc = op->getLoc();

//! remove the first gpu.terminator
m_launch_op->body().front().front().erase();

//! insert "if (tid >= nr_elements) { return; }" at the beginning of the block
rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
Block* cond_block = rewriter.getInsertionBlock();
Block::iterator op_position = rewriter.getInsertionPoint();
Block* remaining_ops_block =
rewriter.splitBlock(cond_block, op_position);
rewriter.setInsertionPointToEnd(cond_block);

auto index = get_tid(rewriter, loc);
auto comparison = rewriter.create<mlir::CmpIOp>(
loc, CmpIPredicate::sge, index,
m_launch_op->getParentOfType<mlir::FuncOp>()
.getArguments()
.back());

Block* then_block =
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(then_block);
rewriter.create<gpu::TerminatorOp>(loc);

rewriter.setInsertionPointToEnd(cond_block);
rewriter.create<mlir::CondBranchOp>(
loc, comparison, then_block, ArrayRef<Value>(),
remaining_ops_block, ArrayRef<Value>());

rewriter.setInsertionPointToEnd(remaining_ops_block);
rewriter.create<gpu::TerminatorOp>(loc);

return success();
}

private:
gpu::LaunchOp* m_launch_op;
};

struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
: ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
m_launch_op{launch_op} {}

LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();

AssignOpAdaptor assign_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));

auto index = get_tid(rewriter, loc);

auto loaded_lhs =
get_operand(rewriter, loc, assign_adaptor.lhs(), index);
rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);

rewriter.eraseOp(op);
return success();
}

private:
gpu::LaunchOp* m_launch_op;
};

class MgbToGpuLoweringPass
: public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
public:
void runOnFunction() override final {
auto func_op = getFunction();
Location loc = func_op.getLoc();
OpBuilder builder(&func_op.getBody());
Value constantOne = builder.create<ConstantIndexOp>(loc, 1);
gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>(
loc, constantOne, constantOne, constantOne, constantOne,
constantOne, constantOne);
builder.setInsertionPointToEnd(&(launch_op.body().front()));
builder.create<gpu::TerminatorOp>(loc);

OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<gpu::GPUDialect>();
target.addIllegalDialect<MgbDialect>();

patterns.insert<AddOpLowering, AssignOpLowering, ReturnOpLowering>(
&getContext(), &launch_op);

if (failed(applyPartialConversion(func_op, target, patterns))) {
signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
return std::make_unique<MgbToGpuLoweringPass>();
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen
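
The block surgery in ReturnOpLowering produces the classic 1-D bounds guard; written as plain C++ over the flattened thread id computed by get_tid (tid = block_idx * block_dim + thread_idx), the lowered launch body amounts to this sketch:

    void launch_body(const float* a, const float* b, float* out,
                     int64_t nr_elements, int64_t tid) {
        if (tid >= nr_elements)        // cond_block -> then_block
            return;                    // (gpu.terminator)
        out[tid] = a[tid] + b[tid];    // remaining_ops_block
    }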

src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp (+56 -0)

@@ -0,0 +1,56 @@
/**
* \file src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>

using namespace mgb;
using namespace jit;

namespace {

class AffineToLLVMLoweringPass : public PassWrapper<AffineToLLVMLoweringPass,
OperationPass<ModuleOp>> {
void runOnOperation() final {
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();

LLVMTypeConverter typeConverter(&getContext());

OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);

auto module = getOperation();
if (failed(applyFullConversion(module, target, patterns)))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_llvm_pass() {
return std::make_unique<AffineToLLVMLoweringPass>();
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

src/jit/impl/mlir/ir/ops.td (+72 -0)

@@ -0,0 +1,72 @@
/**
* \file src/jit/impl/mlir/ir/ops.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#ifndef MGB_MLIR_OPS
#define MGB_MLIR_OPS

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

include "./shape_inference_interface.td"

def Mgb_Dialect : Dialect {
let name = "mgb";
let cppNamespace = "mgb::jit";
}

class ElemwiseOp<string mnemonic, list<OpTrait> traits = []> :
Op<Mgb_Dialect, mnemonic, traits>;

class GenericOp<string mnemonic, list<OpTrait> traits = []> :
Op<Mgb_Dialect, mnemonic, traits>;

def AddOp : ElemwiseOp<"add",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation";
let description = [{
The "add" operation performs element-wise addition between two tensors.
The shapes of the tensor operands are expected to match.
}];

let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
let results = (outs F32MemRef);

// Specify a parser and printer method.
let parser = [{ return ::parseBinaryOp(parser, result); }];
let printer = [{ return ::printBinaryOp(p, *this); }];

// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"OpBuilder &b, OperationState &state, Value lhs, Value rhs">
];
}

def ReturnOp : GenericOp<"return",
[NoSideEffect, HasParent<"FuncOp">, Terminator]> {
let summary = "return operation";
let description = [{
The "return" operation represents a return operation within a function.
The operation takes no operands and produces no results.
}];

}

def AssignOp : GenericOp<"assign", []> {
let summary = "assign op";
let description = [{
copy the value of lhs into rhs; the operation produces no results
}];

let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
}

#endif // MGB_MLIR_OPS

src/jit/impl/mlir/ir/shape_inference_interface.td (+30 -0)

@@ -0,0 +1,30 @@
/**
* \file src/jit/impl/mlir/ir/shape_inference_interface.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#ifndef MGB_JIT_SHAPE_INFERENCE_INTERFACE
#define MGB_JIT_SHAPE_INFERENCE_INTERFACE

include "mlir/IR/OpBase.td"

def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];

let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "infer_shapes">
];
}

#endif // MGB_JIT_SHAPE_INFERENCE_INTERFACE

src/jit/impl/mlir/ir/shape_inference_pass.cpp (+100 -0)

@@ -0,0 +1,100 @@
/**
* \file src/jit/impl/mlir/ir/shape_inference_pass.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/shape_inference_interface.h"

#include <llvm/ADT/SmallPtrSet.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h>

using namespace mgb;
using namespace jit;

#include "megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc"

namespace {
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, FunctionPass> {
public:
void runOnFunction() override {
auto f = getFunction();

llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist;
f.walk([&](mlir::Operation* op) {
if (returns_dynamic_shape(op))
op_worklist.insert(op);
});

// Iterate on the operations in the worklist until all operations have
// been inferred or no change happened (fix point).
while (!op_worklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(op_worklist, all_operands_inferred);
if (nextop == op_worklist.end())
break;

Operation* op = *nextop;
op_worklist.erase(op);

if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
shapeOp.infer_shapes();
} else {
mgb_log_error(
"unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
}

// If the operation worklist isn't empty, this indicates a failure.
if (!op_worklist.empty()) {
mgb_log_error(
"Shape inference failed, %zu operations couldn't be "
"inferred",
op_worklist.size());
signalPassFailure();
}
}

//! A utility method that returns if the given operation has all of its
//! operands inferred.
static bool all_operands_inferred(Operation* op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<mlir::MemRefType>();
});
}

//! A utility method that returns if the given operation has a dynamically
//! shaped result.
static bool returns_dynamic_shape(Operation* op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<mlir::MemRefType>();
});
}
};

} // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_shape_inference_pass() {
return std::make_unique<ShapeInferencePass>();
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

src/jit/impl/mlir/mlir_gen.cpp (+207 -0)

@@ -0,0 +1,207 @@
/**
* \file src/jit/impl/mlir/mlir_gen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./mlir_gen.h"

#include "./utils.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/opr/basic_arith.h"
#include "megdnn/dtype.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Support/LogicalResult.h>

#include <llvm/ADT/ScopedHashTable.h>
#include <llvm/Support/raw_ostream.h>

using namespace mgb;
using namespace jit;

namespace {
class MLIRGenImpl {
public:
MLIRGenImpl(mlir::MLIRContext& context) : m_builder(&context) {}

std::pair<llvm::StringRef, mlir::OwningModuleRef> gen(
const InternalGraph& internal_graph,
const JITExecutor::Args& args) {
mlir::ModuleOp module =
mlir::ModuleOp::create(m_builder.getUnknownLoc());

//! Create main routine function
auto func_op = gen_func_op(internal_graph, args);
module.push_back(func_op);

if (mlir::failed(mlir::verify(module))) {
module.emitError("module verification error");
return {};
}

return {func_op.getName(), module};
}

private:
mlir::OpBuilder m_builder;
llvm::ScopedHashTable<mlir::StringRef, mlir::Value> m_symbol_table;

mlir::FuncOp gen_func_op(const InternalGraph& internal_graph,
const JITExecutor::Args& args) {
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(
m_symbol_table);
std::vector<mlir::Type> func_args;
for (auto&& arg : args.inputs) {
func_args.push_back(get_type(arg.layout));
}
for (auto&& arg : args.outputs) {
func_args.push_back(get_type(arg.layout));
}
//! the last arg is nr_elements
func_args.push_back(m_builder.getIndexType());

auto func_type = m_builder.getFunctionType(func_args, llvm::None);
//! the function name may be renamed by a later pass
mlir::FuncOp func_op = mlir::FuncOp::create(m_builder.getUnknownLoc(),
"func", func_type);
if (!func_op)
return nullptr;

func_op.setAttr("llvm.emit_c_interface",
mlir::UnitAttr::get(m_builder.getContext()));
auto& entry_block = *func_op.addEntryBlock();
size_t idx = 0;
for (auto&& input : args.inputs) {
if (mlir::failed(declare(internal_graph.placeholders()[input.idx]
->output(0)
->name(),
entry_block.getArgument(idx)))) {
return nullptr;
}
idx++;
}
for (auto&& output : args.outputs) {
if (mlir::failed(declare(output.from->name(),
entry_block.getArgument(idx)))) {
return nullptr;
}
idx++;
}

m_builder.setInsertionPointToStart(&entry_block);

if (mlir::failed(gen_func_body(internal_graph, args))) {
func_op.erase();
return nullptr;
}

//! terminate the function body; gen_func_body never emits an explicit
//! return, so one is always appended here
m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc());
std::string op_content = to_string(func_op);
func_op.setName(
ssprintf("jit_mlir_%" PRIx64,
XXHash{}.update(op_content.data(), op_content.size())
.digest()));
return func_op;
}

mlir::LogicalResult gen_func_body(const InternalGraph& internal_graph,
const JITExecutor::Args& args) {
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(
m_symbol_table);
cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
if (opr->same_type<JITPlaceholder>()) {
return;
}

if (opr->same_type<opr::Elemwise>()) {
auto&& out = gen_op(opr->cast_final<opr::Elemwise>());
mgb_assert(
mlir::succeeded(declare(opr->output(0)->name(), out)));
}
}}.add(internal_graph.output());
m_builder.create<AssignOp>(m_builder.getUnknownLoc(),
get(internal_graph.output()),
get(args.outputs[0].from));

return mlir::success();
}

mlir::Value gen_op(const opr::Elemwise& opr) {
switch (opr.param().mode) {
case opr::Elemwise::Mode::ADD:
return m_builder.create<AddOp>(m_builder.getUnknownLoc(),
get(opr.input(0)),
get(opr.input(1)));
break;
default:
return nullptr;
}
return nullptr;
}

mlir::Type get_type(const TensorLayout& layout) {
std::vector<int64_t> shape;
for (size_t i = 0; i < layout.ndim; i++) {
shape.push_back(layout[i]);
}
mgb_assert(layout.ndim != 0);
switch (layout.dtype.enumv()) {
case DTypeEnum::Float32:
return mlir::MemRefType::get(shape, m_builder.getF32Type());
default:
mgb_throw(InternalError, "No supported dtype: %s",
layout.dtype.name());
}
return mlir::UnrankedMemRefType::get(m_builder.getNoneType(), 0);
}

mlir::Value get(const VarNode* var) {
if (auto ret = m_symbol_table.lookup(var->name())) {
return ret;
}
mgb_throw(InternalError, "Unknown var: %s", var->cname());
}

mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
if (m_symbol_table.count(var)) {
return mlir::failure();
}
m_symbol_table.insert(var, value);
return mlir::success();
}
};
} // namespace

std::pair<llvm::StringRef, mlir::OwningModuleRef> mgb::jit::mlir_gen(
mlir::MLIRContext& context,
const mgb::jit::InternalGraph& internal_graph,
const mgb::jit::JITExecutor::Args& args) {
return MLIRGenImpl(context).gen(internal_graph, args);
}

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen
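
Because gen_func_op sets llvm.emit_c_interface, the CPU executable can invoke a _mlir_ciface_-prefixed wrapper; for a fused a + b -> out kernel on rank-2 f32 tensors its prototype would look roughly like this (the hash suffix is hypothetical, derived from the XXHash of the printed op):

    extern "C" void _mlir_ciface_jit_mlir_0123abcd(
            StridedMemRefType<float, 2>* a,
            StridedMemRefType<float, 2>* b,
            StridedMemRefType<float, 2>* out,
            int64_t nr_elements);  // trailing index argument from gen_func_op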

src/jit/impl/mlir/mlir_gen.h (+42 -0)

@@ -0,0 +1,42 @@
/**
* \file src/jit/impl/mlir/mlir_gen.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/jit/executor_opr.h"
#include "megbrain/jit/internal_graph.h"

#include <mlir/IR/Module.h>

namespace mgb {
namespace jit {

/**
* \brief generate mlir from subgraph.
*
* \param context mlir context
* \param internal_graph internal graph used to generate mlir
* \param args input args for the internal graph
* \return A pair of {kernel_name, module}
**/
std::pair<llvm::StringRef, mlir::OwningModuleRef> mlir_gen(
mlir::MLIRContext& context, const InternalGraph& internal_graph,
const JITExecutor::Args& args);
} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

src/jit/impl/mlir/utils.h (+45 -0)

@@ -0,0 +1,45 @@
/**
* \file src/jit/impl/mlir/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"

#include <string>

#include <mlir/ExecutionEngine/CRunnerUtils.h>

#include <llvm/Support/raw_ostream.h>

namespace mgb {
namespace jit {

template <typename T>
std::string to_string(T&& t) {
std::string ret;
llvm::raw_string_ostream stream(ret);
t.print(stream);
return ret;
}

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/include/megbrain/jit/mlir/ir/dialect.h (+45 -0)

@@ -0,0 +1,45 @@
/**
* \file src/jit/include/megbrain/jit/mlir/ir/dialect.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>

#include "megbrain/jit/mlir/ir/shape_inference_interface.h"

namespace mgb {
namespace jit {

class MgbDialect : public ::mlir::Dialect {
public:
explicit MgbDialect(::mlir::MLIRContext* ctx);

//! the namespace used by MLIR to register and look up this dialect
static llvm::StringRef getDialectNamespace() { return "mgb::jit"; }
};

#define GET_OP_CLASSES
using namespace mlir;
#include "megbrain/jit/mlir/ir/ops.h.inc"

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

src/jit/include/megbrain/jit/mlir/ir/passes.h (+43 -0)

@@ -0,0 +1,43 @@
/**
* \file src/jit/include/megbrain/jit/mlir/ir/passes.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include <mlir/IR/Module.h>
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include <memory>

#include <mlir/Pass/Pass.h>

namespace mgb {
namespace jit {

std::unique_ptr<mlir::Pass> create_shape_inference_pass();

/**
* \brief Create a pass for lowering to operations in the `Affine` and `Std`
* dialects, for a subset of the megbrain IR.
*/
std::unique_ptr<mlir::Pass> create_lower_to_affine_pass();

std::unique_ptr<mlir::Pass> create_lower_to_llvm_pass();

std::unique_ptr<mlir::Pass> create_lower_to_gpu_pass();

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h (+33 -0)

@@ -0,0 +1,33 @@
/**
* \file src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "mlir/IR/OpDefinition.h"

namespace mgb {
namespace jit {

/// Include the auto-generated declarations.
#include "megbrain/jit/mlir/ir/shape_inference_interface.h.inc"

} // namespace jit
} // namespace mgb

#endif // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

src/jit/include/megbrain/jit/param_elem_visitor.h (+2 -2)

@@ -18,8 +18,8 @@
*/

/*!
* * \brief fast division for unsigned int
* */
* \brief fast division for unsigned int
*/
struct Uint32Fastdiv {
unsigned int m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift;



src/jit/test/codegen.cpp (+53 -0)

@@ -14,6 +14,7 @@
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/test/helper.h"
#include "megdnn/dtype.h"

#if MGB_JIT
using namespace mgb;
@@ -120,6 +121,44 @@ void run<grad>(Backend backend, CompNode cn) {

template <>
void run<void>(Backend, CompNode) {}

#if MGB_JIT_MLIR
void run_mlir(CompNode cn) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;

auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
host_x2 = gen({23, 42}, cn);

auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
b = opr::Host2DeviceCopy::make(*graph, host_x1),
c = opr::Host2DeviceCopy::make(*graph, host_x2);

auto y = a + b + c;

VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()};
auto ig_gen =
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());

for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}

auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());

HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();

MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
#endif

} // anonymous namespace

#if MGB_JIT_HALIDE
@@ -140,6 +179,20 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0"));
}

#if MGB_JIT_MLIR
TEST(TestJITMlirCodeGen, Basic) {
auto cn = CompNode::load("cpu0");
run_mlir(cn);
}

TEST(TestJITMlirCodeGen, BasicGPU) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
run_mlir(cn);
}

#endif

#endif // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/jit/test/helper.cpp (+3 -0)

@@ -35,6 +35,9 @@ void jit::set_backend(Backend backend) {
case Backend::NVRTC:
setenv("MGB_JIT_BACKEND", "NVRTC", 1);
return;
case Backend::MLIR:
setenv("MGB_JIT_BACKEND", "MLIR", 1);
return;
default:
mgb_assert(0);
}


src/jit/test/helper.h (+1 -1)

@@ -15,7 +15,7 @@

namespace mgb {
namespace jit {
enum class Backend { NONE, HALIDE, NVRTC };
enum class Backend { NONE, HALIDE, NVRTC, MLIR };

void set_backend(Backend backend);


