GitOrigin-RevId: 814fed047e
tags/v1.0.0-rc1
@@ -4,3 +4,4 @@ dnn/src/cuda/conv_bias/int8/kimpl/* binary | |||||
dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | ||||
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | ||||
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | ||||
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text |
@@ -162,6 +162,15 @@ void mgb::_on_cuda_error(const char* expr, cudaError_t err, const char* file, | |||||
cudaGetErrorString(err), expr, file, func, line); | cudaGetErrorString(err), expr, file, func, line); | ||||
} | } | ||||
void mgb::_on_cuda_cu_error(const char* expr, CUresult err, const char* file, | |||||
const char* func, int line) { | |||||
const char* msg; | |||||
cuGetErrorName(err, &msg); | |||||
mgb_throw(CudaError, "cuda error %d: %s (%s at %s:%s:%d)", int(err), msg, | |||||
expr, file, func, line); | |||||
} | |||||
void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, | void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, | ||||
const ContinuationCtx<cudaStream_t>& cont) { | const ContinuationCtx<cudaStream_t>& cont) { | ||||
m_comp_node = comp_node; | m_comp_node = comp_node; | ||||
@@ -22,6 +22,7 @@ | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
#include <cuda_runtime.h> | #include <cuda_runtime.h> | ||||
#include <cuda.h> | |||||
#if MGB_ENABLE_LOGGING | #if MGB_ENABLE_LOGGING | ||||
#define MGB_CUDA_CHECK(expr) \ | #define MGB_CUDA_CHECK(expr) \ | ||||
@@ -32,6 +33,16 @@ | |||||
__func__, __LINE__); \ | __func__, __LINE__); \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
#define MGB_CUDA_CU_CHECK(expr) \ | |||||
do { \ | |||||
CUresult __cuda_check_code = (expr); \ | |||||
if (!mgb_likely(__cuda_check_code == CUDA_SUCCESS)) { \ | |||||
::mgb::_on_cuda_cu_error(#expr, __cuda_check_code, __FILE__, \ | |||||
__func__, __LINE__); \ | |||||
} \ | |||||
} while (0) | |||||
#else | #else | ||||
#define MGB_CUDA_CHECK(expr) \ | #define MGB_CUDA_CHECK(expr) \ | ||||
do { \ | do { \ | ||||
@@ -41,6 +52,14 @@ | |||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
#define MGB_CUDA_CU_CHECK(expr) \ | |||||
do { \ | |||||
CUresult __cuda_check_code = (expr); \ | |||||
if (!mgb_likely(__cuda_check_code == CUDA_SUCCESS)) { \ | |||||
::mgb::_on_cuda_cu_error(#expr, __cuda_check_code, "", "", 1); \ | |||||
} \ | |||||
} while (0) | |||||
#endif //MGB_ENABLE_LOGGING | #endif //MGB_ENABLE_LOGGING | ||||
#endif //MGB_CUDA | #endif //MGB_CUDA | ||||
@@ -178,6 +197,9 @@ namespace mgb { | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
[[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err, | [[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err, | ||||
const char* file, const char* func, int line); | const char* file, const char* func, int line); | ||||
[[noreturn]] void _on_cuda_cu_error(const char* expr, CUresult err, | |||||
const char* file, const char* func, | |||||
int line); | |||||
#endif | #endif | ||||
@@ -11,6 +11,7 @@ | |||||
#include "./halide/compiler_cuda.h" | #include "./halide/compiler_cuda.h" | ||||
#include "./nvrtc/compiler_cuda.h" | #include "./nvrtc/compiler_cuda.h" | ||||
#include "./mlir/compiler.h" | |||||
#include "megbrain/jit/compiler.h" | #include "megbrain/jit/compiler.h" | ||||
#include "megbrain/utils/hash.h" | #include "megbrain/utils/hash.h" | ||||
@@ -54,6 +55,8 @@ bool Compiler::is_supported_device(CompNode::DeviceType device) { | |||||
case CompNode::DeviceType::CUDA: | case CompNode::DeviceType::CUDA: | ||||
return true; | return true; | ||||
#endif | #endif | ||||
case CompNode::DeviceType::CPU: | |||||
return true; | |||||
default: | default: | ||||
return false; | return false; | ||||
} | } | ||||
@@ -87,12 +90,30 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) { | |||||
break; | break; | ||||
} | } | ||||
#endif | #endif | ||||
#if MGB_JIT_MLIR | |||||
if (!backend || !strcmp(backend, "MLIR")) { | |||||
compiler = std::make_unique<MLIRCompiler>( | |||||
CompNode::DeviceType::CUDA); | |||||
break; | |||||
} | |||||
#endif | |||||
if (!backend || !strcmp(backend, "NVRTC")) { | if (!backend || !strcmp(backend, "NVRTC")) { | ||||
compiler = std::make_unique<CudaCompiler>(); | compiler = std::make_unique<CudaCompiler>(); | ||||
break; | break; | ||||
} | } | ||||
#endif | #endif | ||||
// fall through | |||||
mgb_throw(InternalError, "No compiler support for cuda"); | |||||
break; | |||||
case CompNode::DeviceType::CPU: | |||||
#if MGB_JIT_MLIR | |||||
if (!backend || !strcmp(backend, "MLIR")) { | |||||
compiler = std::make_unique<MLIRCompiler>( | |||||
CompNode::DeviceType::CPU); | |||||
break; | |||||
} | |||||
#endif | |||||
mgb_throw(InternalError, "No compiler support for cpu"); | |||||
break; | |||||
default: | default: | ||||
mgb_throw(InternalError, | mgb_throw(InternalError, | ||||
"unsupported JIT config: " | "unsupported JIT config: " | ||||
@@ -0,0 +1,198 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/compiler.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "./compiler.h" | |||||
#include "./executable_cpu.h" | |||||
#include "./executable_cuda.h" | |||||
#include "./mlir_gen.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/jit/mlir/ir/passes.h" | |||||
#include <mlir/Conversion/GPUCommon/GPUCommonPass.h> | |||||
#include <mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h> | |||||
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h> | |||||
#include <mlir/Dialect/GPU/Passes.h> | |||||
#include <mlir/IR/Dialect.h> | |||||
#include <mlir/IR/MLIRContext.h> | |||||
#include <mlir/IR/Module.h> | |||||
#include <mlir/InitAllDialects.h> | |||||
#include <mlir/Pass/PassManager.h> | |||||
#include <mlir/Support/LogicalResult.h> | |||||
#include <mlir/Target/NVVMIR.h> | |||||
#include <mlir/Transforms/Passes.h> | |||||
#include <llvm/Support/TargetSelect.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
struct LLVMInitializer { | |||||
LLVMInitializer() { | |||||
llvm::InitializeNativeTarget(); | |||||
llvm::InitializeNativeTargetAsmPrinter(); | |||||
} | |||||
}; | |||||
static LLVMInitializer initializer; | |||||
#if MGB_CUDA | |||||
mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location, | |||||
llvm::StringRef) { | |||||
OwnedBlob result = std::make_unique<std::vector<char>>( | |||||
ptx.data(), ptx.data() + ptx.size()); | |||||
return result; | |||||
} | |||||
#endif | |||||
void add_cpu_lowering_pass(mlir::PassManager& manager) { | |||||
{ | |||||
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>(); | |||||
opt_pm.addPass(create_shape_inference_pass()); | |||||
opt_pm.addPass(mlir::createCanonicalizerPass()); | |||||
opt_pm.addPass(mlir::createCSEPass()); | |||||
} | |||||
manager.addPass(create_lower_to_affine_pass()); | |||||
{ | |||||
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>(); | |||||
opt_pm.addPass(mlir::createCanonicalizerPass()); | |||||
opt_pm.addPass(mlir::createCSEPass()); | |||||
opt_pm.addPass(mlir::createLoopFusionPass()); | |||||
opt_pm.addPass(mlir::createMemRefDataFlowOptPass()); | |||||
} | |||||
manager.addPass(create_lower_to_llvm_pass()); | |||||
} | |||||
#if MGB_CUDA | |||||
void add_cuda_lowering_pass(mlir::PassManager& manager, CompNode cn) { | |||||
{ | |||||
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>(); | |||||
opt_pm.addPass(create_shape_inference_pass()); | |||||
opt_pm.addPass(mlir::createCanonicalizerPass()); | |||||
opt_pm.addPass(mlir::createCSEPass()); | |||||
} | |||||
manager.addPass(create_lower_to_gpu_pass()); | |||||
{ | |||||
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>(); | |||||
opt_pm.addPass(mlir::createCanonicalizerPass()); | |||||
opt_pm.addPass(mlir::createCSEPass()); | |||||
opt_pm.addPass(mlir::createLoopFusionPass()); | |||||
opt_pm.addPass(mlir::createMemRefDataFlowOptPass()); | |||||
} | |||||
manager.addPass(mlir::createGpuKernelOutliningPass()); | |||||
{ | |||||
auto& kernel_pm = manager.nest<gpu::GPUModuleOp>(); | |||||
kernel_pm.addPass(mlir::createLowerGpuOpsToNVVMOpsPass()); | |||||
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; | |||||
kernel_pm.addPass(mlir::createConvertGPUKernelToBlobPass( | |||||
mlir::translateModuleToNVVMIR, compile_ptx_to_cubin, | |||||
"nvptx64-nvidia-cuda", | |||||
ssprintf("sm_%d%d", prop.major, prop.minor), "+ptx60", | |||||
MLIRCUDAExecutable::sm_blob_annotation)); | |||||
} | |||||
} | |||||
#endif | |||||
} // namespace | |||||
/* ==================== MLIRCompiler ===================== */ | |||||
thread_local mlir::MLIRContext MLIRCompiler::sm_ctx; | |||||
MLIRCompiler::MLIRCompiler(CompNode::DeviceType device_type) | |||||
: m_device_type{device_type} { | |||||
mlir::registerAllDialects(); | |||||
mlir::registerDialect<MgbDialect>(); | |||||
#if MGB_CUDA | |||||
if (m_device_type == CompNode::DeviceType::CUDA) { | |||||
LLVMInitializeNVPTXTarget(); | |||||
LLVMInitializeNVPTXTargetInfo(); | |||||
LLVMInitializeNVPTXTargetMC(); | |||||
LLVMInitializeNVPTXAsmPrinter(); | |||||
} | |||||
#endif | |||||
} | |||||
void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module, | |||||
CompNode cn) { | |||||
mgb_assert(cn.device_type() == m_device_type); | |||||
mlir::PassManager manager(module->getContext()); | |||||
switch (m_device_type) { | |||||
case CompNode::DeviceType::CPU: | |||||
add_cpu_lowering_pass(manager); | |||||
break; | |||||
#if MGB_CUDA | |||||
case CompNode::DeviceType::CUDA: | |||||
add_cuda_lowering_pass(manager, cn); | |||||
break; | |||||
#endif | |||||
default: | |||||
mgb_throw(InternalError, "Unsupport device type: %d", | |||||
static_cast<int>(m_device_type)); | |||||
break; | |||||
} | |||||
mgb_assert(mlir::succeeded(manager.run(*module))); | |||||
} | |||||
std::unique_ptr<Executable> MLIRCompiler::do_compile( | |||||
const InternalGraph& graph, const JITExecutor::Args& args) { | |||||
MGB_MARK_USED_VAR(graph); | |||||
MGB_MARK_USED_VAR(args); | |||||
mlir::MLIRContext ctx; | |||||
ctx.printStackTraceOnDiagnostic(true); | |||||
ctx.printOpOnDiagnostic(true); | |||||
auto&& res = mlir_gen(ctx, graph, args); | |||||
mgb_assert(res.second, "failed to generate module"); | |||||
CompNode cn = args.owner->comp_node(); | |||||
run_lowering_pass(res.second, cn); | |||||
switch (cn.device_type()) { | |||||
case CompNode::DeviceType::CPU: | |||||
return std::make_unique<MLIRCPUExecutable>(res.second, | |||||
res.first.str()); | |||||
#if MGB_CUDA | |||||
case CompNode::DeviceType::CUDA: | |||||
return std::make_unique<MLIRCUDAExecutable>(res.second, | |||||
res.first.str()); | |||||
#endif | |||||
default: | |||||
mgb_throw(InternalError, "Unsupport device type: %d", | |||||
static_cast<int>(cn.device_type())); | |||||
return nullptr; | |||||
} | |||||
} | |||||
size_t MLIRCompiler::get_nr_workspace_outputs(JITExecutor* opr) const { | |||||
MGB_MARK_USED_VAR(opr); | |||||
return 0; | |||||
} | |||||
void MLIRCompiler::init_workspace_size_infer(JITExecutor* opr) { | |||||
MGB_MARK_USED_VAR(opr); | |||||
} | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/compiler.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/jit/compiler.h" | |||||
#include <mlir/ExecutionEngine/ExecutionEngine.h> | |||||
#include <mlir/IR/Module.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
/*! | |||||
* \brief MLIR compiler | |||||
*/ | |||||
class MLIRCompiler final : public Compiler { | |||||
std::unique_ptr<Executable> do_compile( | |||||
const InternalGraph& graph, const JITExecutor::Args& args) override; | |||||
public: | |||||
MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU); | |||||
Property property() const override { | |||||
using F = Property::Flag; | |||||
return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM, | |||||
JITFeatureBits::NONE, 64}; | |||||
} | |||||
size_t get_nr_workspace_outputs(JITExecutor* opr) const override; | |||||
void init_workspace_size_infer(JITExecutor* opr) override; | |||||
private: | |||||
void run_lowering_pass(mlir::OwningModuleRef& module, CompNode cn); | |||||
CompNode::DeviceType m_device_type; | |||||
static thread_local mlir::MLIRContext sm_ctx; | |||||
}; | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,134 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/executable_cpu.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "./executable_cpu.h" | |||||
#include "./utils.h" | |||||
#include <mlir/ExecutionEngine/OptUtils.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
template <int N> | |||||
void* tensor2memref_dim(const megdnn::TensorND& tensor) { | |||||
switch (tensor.layout.dtype.enumv()) { | |||||
case megdnn::DTypeEnum::Float32: { | |||||
StridedMemRefType<float, N>* desc = | |||||
static_cast<StridedMemRefType<float, N>*>( | |||||
malloc(sizeof(StridedMemRefType<float, N>))); | |||||
desc->basePtr = tensor.ptr<float>(); | |||||
desc->data = tensor.ptr<float>(); | |||||
desc->offset = 0; | |||||
for (size_t i = 0; i < tensor.layout.ndim; i++) { | |||||
desc->sizes[i] = tensor.layout.shape[i]; | |||||
desc->strides[i] = tensor.layout.stride[i]; | |||||
} | |||||
return desc; | |||||
break; | |||||
} | |||||
default: | |||||
mgb_throw(InternalError, "Unsupport dtype, got %s", | |||||
tensor.layout.dtype.name()); | |||||
break; | |||||
} | |||||
return nullptr; | |||||
} | |||||
void* tensor2memref(const megdnn::TensorND& tensor) { | |||||
switch (tensor.layout.ndim) { | |||||
#define cb(i) \ | |||||
case i: \ | |||||
return tensor2memref_dim<i>(tensor) | |||||
cb(1); | |||||
cb(2); | |||||
cb(3); | |||||
cb(4); | |||||
cb(5); | |||||
default: | |||||
mgb_throw(InternalError, "Unsupported ndim, got %zu", | |||||
tensor.layout.ndim); | |||||
#undef cb | |||||
} | |||||
} | |||||
} // namespace | |||||
MLIRCPUExecutable::MLIRCPUExecutable(mlir::OwningModuleRef& module, | |||||
const std::string& kernel_name) | |||||
: m_kernel_name{kernel_name} { | |||||
auto opt_pipeline = mlir::makeOptimizingTransformer(3, 3, 0); | |||||
std::vector<std::string> libs; | |||||
auto&& engine = mlir::ExecutionEngine::create( | |||||
*module, opt_pipeline, llvm::None, | |||||
std::vector<llvm::StringRef>(libs.begin(), libs.end()), true, | |||||
false); | |||||
mgb_assert(engine); | |||||
m_engine = std::move(*engine); | |||||
} | |||||
void MLIRCPUExecutable::execute(JITExecutor* fusion_opr) { | |||||
auto&& args = fusion_opr->args(); | |||||
std::vector<void*> args_array(args.inputs.size() + args.outputs.size()); | |||||
std::vector<void*> args_array_pointer(args.inputs.size() + | |||||
args.outputs.size()); | |||||
size_t idx = 0; | |||||
for (size_t i = 0; i < args.inputs.size(); i++) { | |||||
args_array[idx] = | |||||
tensor2memref({args.inputs[i].from->dev_tensor().raw_ptr(), | |||||
args.inputs[i].layout}); | |||||
args_array_pointer[idx] = &args_array[idx]; | |||||
idx++; | |||||
} | |||||
int64_t nr_elements = 0; | |||||
for (size_t i = 0; i < args.outputs.size(); i++) { | |||||
if (nr_elements == 0) { | |||||
nr_elements = args.outputs[i].layout.total_nr_elems(); | |||||
} else { | |||||
mgb_assert(static_cast<size_t>(nr_elements) == | |||||
args.outputs[i].layout.total_nr_elems(), | |||||
"The number of elements of outputs mismatch, expected: " | |||||
"%zu got: %zu(%s)", | |||||
static_cast<size_t>(nr_elements), | |||||
args.outputs[i].layout.total_nr_elems(), | |||||
args.outputs[i].layout.to_string().c_str()); | |||||
} | |||||
args_array[idx] = | |||||
tensor2memref({args.outputs[i].from->dev_tensor().raw_ptr(), | |||||
args.outputs[i].layout}); | |||||
args_array_pointer[idx] = &args_array[idx]; | |||||
idx++; | |||||
} | |||||
args_array_pointer[idx++] = &nr_elements; | |||||
std::string adapter_name = std::string("_mlir_ciface_") + m_kernel_name; | |||||
auto err = m_engine->invoke( | |||||
adapter_name, llvm::MutableArrayRef<void*>(args_array_pointer)); | |||||
if (err) { | |||||
mgb_throw(InternalError, "failed to run MLIR kernel %s\n", | |||||
m_kernel_name.c_str()); | |||||
} | |||||
for (size_t i = 0; i < args_array.size(); i++) { | |||||
free(args_array[i]); | |||||
} | |||||
} | |||||
MLIRCPUExecutable::~MLIRCPUExecutable() {} | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,51 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/executable_cpu.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/jit/compiler.h" | |||||
#include <mlir/ExecutionEngine/ExecutionEngine.h> | |||||
#include <mlir/IR/Module.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
/*! | |||||
* \brief Executable class for MLIR | |||||
*/ | |||||
class MLIRCPUExecutable final : public Executable { | |||||
public: | |||||
MLIRCPUExecutable(mlir::OwningModuleRef& module, | |||||
const std::string& kernel_name); | |||||
~MLIRCPUExecutable(); | |||||
/*! | |||||
* \brief execute | |||||
* A executable instance can be executed by one or more fusion_opr | |||||
*/ | |||||
void execute(JITExecutor* fusion_opr) override final; | |||||
private: | |||||
std::unique_ptr<mlir::ExecutionEngine> m_engine; | |||||
std::string m_kernel_name; | |||||
}; | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,166 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/executable_cuda.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <vector> | |||||
#include "megbrain_build_config.h" | |||||
#include "megdnn/dtype.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#if MGB_CUDA | |||||
#include "./executable_cuda.h" | |||||
#include "./utils.h" | |||||
#include "megbrain/utils/timer.h" | |||||
#include "megbrain/utils/persistent_cache.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include <mlir/ExecutionEngine/OptUtils.h> | |||||
#include <mlir/Dialect/GPU/GPUDialect.h> | |||||
#include <mlir/IR/OpDefinition.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
template <int out_dim, typename ctype> | |||||
void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, | |||||
int block_size) { | |||||
auto&& args = fusion_opr->args(); | |||||
std::vector<StridedMemRefType<ctype, out_dim>> param_holders; | |||||
std::vector<void*> params; | |||||
auto set_params = [¶m_holders, ¶ms]( | |||||
void* ptr, const megdnn::TensorLayout& layout) { | |||||
param_holders.push_back(StridedMemRefType<ctype, out_dim>{}); | |||||
StridedMemRefType<ctype, out_dim>& desc = param_holders.back(); | |||||
desc.basePtr = static_cast<ctype*>(ptr); | |||||
params.push_back(&(desc.basePtr)); | |||||
desc.data = static_cast<ctype*>(ptr); | |||||
params.push_back(&(desc.data)); | |||||
desc.offset = 0; | |||||
params.push_back(&(desc.offset)); | |||||
for (size_t i = 0; i < layout.ndim; i++) { | |||||
desc.sizes[i] = layout.shape[i]; | |||||
params.push_back(&(desc.sizes[i])); | |||||
desc.strides[i] = layout.stride[i]; | |||||
params.push_back(&(desc.strides[i])); | |||||
} | |||||
}; | |||||
for (const auto& arg : args.inputs) { | |||||
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout); | |||||
} | |||||
int64_t nr_elements = 0; | |||||
for (const auto& arg : args.outputs) { | |||||
if (nr_elements == 0) { | |||||
nr_elements = arg.layout.total_nr_elems(); | |||||
} else { | |||||
mgb_assert(static_cast<size_t>(nr_elements) == | |||||
arg.layout.total_nr_elems(), | |||||
"The number of elements of outputs mismatch, expected: " | |||||
"%zu got: %zu(%s)", | |||||
static_cast<size_t>(nr_elements), | |||||
arg.layout.total_nr_elems(), | |||||
arg.layout.to_string().c_str()); | |||||
} | |||||
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout); | |||||
} | |||||
const CompNodeEnv& env = | |||||
CompNodeEnv::from_comp_node(fusion_opr->comp_node()); | |||||
int64_t num_block = (nr_elements - 1) / block_size + 1; | |||||
params.insert(params.begin(), &nr_elements); | |||||
MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0, | |||||
env.cuda_env().stream, params.data(), 0)); | |||||
} | |||||
} // namespace | |||||
const std::string MLIRCUDAExecutable::sm_blob_annotation = "nvvm.cubin"; | |||||
MLIRCUDAExecutable::MLIRCUDAExecutable(mlir::OwningModuleRef& module, | |||||
const std::string& kernel_name) { | |||||
m_kernel_name = kernel_name + "_kernel"; | |||||
auto kernel_module = | |||||
module->lookupSymbol<mlir::gpu::GPUModuleOp>(m_kernel_name); | |||||
mgb_assert(kernel_module, "Expected gpu kernel module"); | |||||
auto binary_attr = kernel_module.getAttrOfType<mlir::StringAttr>( | |||||
llvm::StringRef(sm_blob_annotation)); | |||||
mgb_assert(binary_attr, "Missing %s attribute in gpu kernel module", | |||||
sm_blob_annotation.c_str()); | |||||
m_kernel_data = binary_attr.getValue().str(); | |||||
} | |||||
void MLIRCUDAExecutable::execute(JITExecutor* fusion_opr) { | |||||
FuncCache* func; | |||||
auto cn = fusion_opr->comp_node(); | |||||
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; | |||||
func = &m_func_cache[{prop.major, prop.minor}]; | |||||
func->kernel_data = m_kernel_data; | |||||
func->exec(fusion_opr, this); | |||||
} | |||||
MLIRCUDAExecutable::~MLIRCUDAExecutable() {} | |||||
void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr, | |||||
const MLIRCUDAExecutable* cuda_exe) { | |||||
Func* func; | |||||
{ | |||||
MGB_LOCK_GUARD(mtx); | |||||
auto ins = cn2func.insert({fusion_opr->comp_node(), {}}); | |||||
func = &ins.first->second; | |||||
if (ins.second) { | |||||
MGB_CUDA_CU_CHECK( | |||||
cuModuleLoadData(&func->module, kernel_data.data())); | |||||
MGB_CUDA_CU_CHECK( | |||||
cuModuleGetFunction(&func->func, func->module, | |||||
cuda_exe->m_kernel_name.c_str())); | |||||
int min_grid_size = 0; | |||||
MGB_CUDA_CU_CHECK(cuOccupancyMaxPotentialBlockSize( | |||||
&min_grid_size, &func->block_size, func->func, nullptr, 0, | |||||
0)); | |||||
} | |||||
} | |||||
mgb_assert(fusion_opr->args().outputs.size() == 1, | |||||
"Currently only support 1 outputs, got %zu", | |||||
fusion_opr->args().outputs.size()); | |||||
int out_dim = fusion_opr->args().outputs[0].layout.ndim; | |||||
DType dtype = fusion_opr->args().outputs[0].layout.dtype; | |||||
#define cb_outdim(_ndim, _dtype) \ | |||||
if (_ndim == out_dim) { \ | |||||
setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \ | |||||
func->block_size); \ | |||||
return; \ | |||||
} | |||||
#define cb(_dtype) \ | |||||
cb_outdim(1, float); \ | |||||
cb_outdim(2, float); \ | |||||
cb_outdim(3, float); \ | |||||
cb_outdim(4, float); \ | |||||
mgb_throw(InternalError, "unsupported out_dim=%zu", \ | |||||
static_cast<size_t>(out_dim)); \ | |||||
return; | |||||
switch (dtype.enumv()) { | |||||
case DTypeEnum::Float32: | |||||
cb(float); | |||||
default: | |||||
mgb_throw(InternalError, "unsupport dtype: %s", dtype.name()); | |||||
} | |||||
#undef cb | |||||
#undef cb_outdim | |||||
} | |||||
#endif // MGB_CUDA | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,74 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/executable_cuda.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#if MGB_CUDA | |||||
#include "megbrain/jit/compiler.h" | |||||
#include <mlir/IR/Module.h> | |||||
#include <cuda.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
/*! | |||||
* \brief Executable class for MLIR | |||||
*/ | |||||
class MLIRCUDAExecutable final : public Executable { | |||||
public: | |||||
MLIRCUDAExecutable(mlir::OwningModuleRef& module, | |||||
const std::string& kernel_name); | |||||
~MLIRCUDAExecutable(); | |||||
/*! | |||||
* \brief execute | |||||
* A executable instance can be executed by one or more fusion_opr | |||||
*/ | |||||
void execute(JITExecutor* fusion_opr) override final; | |||||
const static std::string sm_blob_annotation; | |||||
private: | |||||
//! cache for a func on a specific device | |||||
struct FuncCache { | |||||
struct Func { | |||||
int block_size{-1}; | |||||
CUmodule module{nullptr}; | |||||
CUfunction func{nullptr}; | |||||
}; | |||||
std::mutex mtx; | |||||
std::string kernel_data; | |||||
CompNode::UnorderedMap<Func> cn2func; | |||||
void exec(const JITExecutor* fusion_opr, | |||||
const MLIRCUDAExecutable* cuda_exe); | |||||
}; | |||||
std::string m_kernel_name; | |||||
std::string m_kernel_data; | |||||
//! (cuda_major, cuda_minor) => func | |||||
ThinHashMap<std::pair<uint32_t, uint32_t>, FuncCache> m_func_cache; | |||||
}; | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_CUDA | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/common.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "common.h" | |||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, | |||||
mlir::Location loc, | |||||
mlir::PatternRewriter& rewriter) { | |||||
auto alloc = rewriter.create<mlir::AllocOp>(loc, type); | |||||
// Make sure to allocate at the beginning of the block. | |||||
auto* parent_block = alloc.getOperation()->getBlock(); | |||||
alloc.getOperation()->moveBefore(&parent_block->front()); | |||||
// Make sure to deallocate this alloc at the end of the block. This is fine | |||||
// as toy functions have no control flow. | |||||
auto dealloc = rewriter.create<mlir::DeallocOp>(loc, alloc); | |||||
dealloc.getOperation()->moveBefore(&parent_block->back()); | |||||
return alloc; | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/common.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include <mlir/IR/PatternMatch.h> | |||||
#include <mlir/IR/StandardTypes.h> | |||||
#include <mlir/IR/Value.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, | |||||
mlir::PatternRewriter& rewriter); | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,91 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/dialect.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include <mlir/Support/LogicalResult.h> | |||||
#include <mlir/IR/Builders.h> | |||||
#include <mlir/IR/OpImplementation.h> | |||||
#include <mlir/IR/StandardTypes.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
MgbDialect::MgbDialect(mlir::MLIRContext *ctx) : mlir::Dialect("mgb", ctx) { | |||||
addOperations< | |||||
#define GET_OP_LIST | |||||
#include "megbrain/jit/mlir/ir/ops.cpp.inc" | |||||
>(); | |||||
} | |||||
static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, | |||||
mlir::OperationState &result) { | |||||
SmallVector<mlir::OpAsmParser::OperandType, 2> operands; | |||||
llvm::SMLoc operandsLoc = parser.getCurrentLocation(); | |||||
Type type; | |||||
if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || | |||||
parser.parseOptionalAttrDict(result.attributes) || | |||||
parser.parseColonType(type)) | |||||
return mlir::failure(); | |||||
// If the type is a function type, it contains the input and result types of | |||||
// this operation. | |||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) { | |||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, | |||||
result.operands)) | |||||
return mlir::failure(); | |||||
result.addTypes(funcType.getResults()); | |||||
return mlir::success(); | |||||
} | |||||
// Otherwise, the parsed type is the type of both operands and results. | |||||
if (parser.resolveOperands(operands, type, result.operands)) | |||||
return mlir::failure(); | |||||
result.addTypes(type); | |||||
return mlir::success(); | |||||
} | |||||
static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { | |||||
printer << op->getName() << " " << op->getOperands(); | |||||
printer.printOptionalAttrDict(op->getAttrs()); | |||||
printer << " : "; | |||||
// If all of the types are the same, print the type directly. | |||||
Type resultType = *op->result_type_begin(); | |||||
if (llvm::all_of(op->getOperandTypes(), | |||||
[=](Type type) { return type == resultType; })) { | |||||
printer << resultType; | |||||
return; | |||||
} | |||||
// Otherwise, print a functional type. | |||||
printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); | |||||
} | |||||
///////////////////////// ElemwiseOp ///////////////////////////////////////////// | |||||
void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, | |||||
mlir::Value lhs, mlir::Value rhs) { | |||||
state.addTypes(lhs.getType()); | |||||
state.addOperands({lhs, rhs}); | |||||
} | |||||
void AddOp::infer_shapes() { getResult().setType(getOperand(0).getType()); } | |||||
#define GET_OP_CLASSES | |||||
#include "megbrain/jit/mlir/ir/ops.cpp.inc" | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,159 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/lower_to_affine_pass.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/jit/mlir/ir/passes.h" | |||||
#include "./common.h" | |||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | |||||
#include <mlir/Dialect/StandardOps/IR/Ops.h> | |||||
#include <mlir/Pass/Pass.h> | |||||
#include <mlir/Transforms/DialectConversion.h> | |||||
#include <llvm/ADT/Sequence.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
using LoopIterationFn = function_ref<Value( | |||||
OpBuilder& rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; | |||||
void lower_op_to_loops(Operation* op, ValueRange operands, | |||||
PatternRewriter& rewriter, | |||||
LoopIterationFn process_iteration) { | |||||
auto memref_type = (*op->result_type_begin()).cast<MemRefType>(); | |||||
auto loc = op->getLoc(); | |||||
auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter); | |||||
SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0); | |||||
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1); | |||||
buildAffineLoopNest( | |||||
rewriter, loc, lower_bounds, memref_type.getShape(), steps, | |||||
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) { | |||||
Value value_to_store = | |||||
process_iteration(nested_builder, operands, ivs); | |||||
nested_builder.create<AffineStoreOp>(loc, value_to_store, alloc, | |||||
ivs); | |||||
}); | |||||
// Replace this operation with the generated alloc. | |||||
rewriter.replaceOp(op, alloc); | |||||
} | |||||
template <typename BinaryOp, typename LoweredBinaryOp> | |||||
struct BinaryOpLowering : public ConversionPattern { | |||||
BinaryOpLowering(MLIRContext* ctx) | |||||
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
auto loc = op->getLoc(); | |||||
lower_op_to_loops( | |||||
op, operands, rewriter, | |||||
[loc](OpBuilder& builder, ValueRange memref_operands, | |||||
ValueRange loop_ivs) { | |||||
typename BinaryOp::Adaptor binary_adaptor(memref_operands); | |||||
auto loaded_lhs = builder.create<AffineLoadOp>( | |||||
loc, binary_adaptor.lhs(), loop_ivs); | |||||
auto loaded_rhs = builder.create<AffineLoadOp>( | |||||
loc, binary_adaptor.rhs(), loop_ivs); | |||||
return builder.create<LoweredBinaryOp>(loc, loaded_lhs, | |||||
loaded_rhs); | |||||
}); | |||||
return success(); | |||||
} | |||||
}; | |||||
using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>; | |||||
struct AssignOpLowering : public ConversionPattern { | |||||
AssignOpLowering(MLIRContext* ctx) | |||||
: ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
auto loc = op->getLoc(); | |||||
auto memref_type = operands[0].getType().cast<MemRefType>(); | |||||
AssignOpAdaptor assign_adaptor(operands); | |||||
SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0); | |||||
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1); | |||||
buildAffineLoopNest( | |||||
rewriter, loc, lower_bounds, memref_type.getShape(), steps, | |||||
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) { | |||||
auto loaded_lhs = nested_builder.create<AffineLoadOp>( | |||||
loc, assign_adaptor.lhs(), ivs); | |||||
nested_builder.create<AffineStoreOp>( | |||||
loc, loaded_lhs, assign_adaptor.rhs(), ivs); | |||||
}); | |||||
rewriter.eraseOp(op); | |||||
return success(); | |||||
} | |||||
}; | |||||
struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { | |||||
using OpRewritePattern<jit::ReturnOp>::OpRewritePattern; | |||||
LogicalResult matchAndRewrite(jit::ReturnOp op, | |||||
PatternRewriter& rewriter) const final { | |||||
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | |||||
return success(); | |||||
} | |||||
}; | |||||
class MgbToAffineLoweringPass | |||||
: public PassWrapper<MgbToAffineLoweringPass, FunctionPass> { | |||||
public: | |||||
void runOnFunction() override final { | |||||
auto function = getFunction(); | |||||
// Verify that the given main has no inputs and results. | |||||
if (function.getType().getNumResults()) { | |||||
mgb_log_error("expected 'main' to have 0 results"); | |||||
return signalPassFailure(); | |||||
} | |||||
ConversionTarget target(getContext()); | |||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | |||||
target.addIllegalDialect<MgbDialect>(); | |||||
OwningRewritePatternList patterns; | |||||
patterns.insert<AddOpLowering, ReturnOpLowering, AssignOpLowering>( | |||||
&getContext()); | |||||
if (failed(applyPartialConversion(getFunction(), target, patterns))) { | |||||
signalPassFailure(); | |||||
} | |||||
} | |||||
}; | |||||
} // namespace | |||||
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() { | |||||
return std::make_unique<MgbToAffineLoweringPass>(); | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,211 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/jit/mlir/ir/passes.h" | |||||
#include "../utils.h" | |||||
#include <mlir/Dialect/GPU/GPUDialect.h> | |||||
#include <mlir/Dialect/SCF/SCF.h> | |||||
#include <mlir/Dialect/StandardOps/IR/Ops.h> | |||||
#include <mlir/EDSC/Builders.h> | |||||
#include <mlir/IR/StandardTypes.h> | |||||
#include <mlir/Pass/Pass.h> | |||||
#include <mlir/Transforms/DialectConversion.h> | |||||
#include <llvm/ADT/PointerUnion.h> | |||||
#include <llvm/ADT/Sequence.h> | |||||
#include <llvm/ADT/SetVector.h> | |||||
#include <llvm/ADT/Twine.h> | |||||
#include <llvm/IR/Type.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
mlir::Value get_operand(ConversionPatternRewriter& rewriter, | |||||
const mlir::Location& loc, const mlir::Value& val, | |||||
const mlir::Value& index) { | |||||
if (val.getType().isa<mlir::MemRefType>()) { | |||||
return rewriter.create<LoadOp>(loc, val, index); | |||||
} else { | |||||
return val; | |||||
} | |||||
} | |||||
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) { | |||||
auto thread_idx = rewriter.create<gpu::ThreadIdOp>( | |||||
loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); | |||||
auto block_idx = rewriter.create<gpu::BlockIdOp>( | |||||
loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); | |||||
auto group_size = rewriter.create<gpu::BlockDimOp>( | |||||
loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); | |||||
Value index = rewriter.create<AddIOp>( | |||||
loc, thread_idx, | |||||
rewriter.create<MulIOp>(loc, block_idx, group_size)); | |||||
return index; | |||||
} | |||||
template <typename BinaryOp, typename LoweredBinaryOp> | |||||
struct BinaryOpLowering : public ConversionPattern { | |||||
BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) | |||||
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx), | |||||
m_launch_op{launch_op} {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
auto loc = op->getLoc(); | |||||
typename BinaryOp::Adaptor binary_adaptor(operands); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front())); | |||||
auto index = get_tid(rewriter, loc); | |||||
auto loaded_lhs = | |||||
get_operand(rewriter, loc, binary_adaptor.lhs(), index); | |||||
auto loaded_rhs = | |||||
get_operand(rewriter, loc, binary_adaptor.rhs(), index); | |||||
auto binary_op = | |||||
rewriter.create<LoweredBinaryOp>(loc, loaded_lhs, loaded_rhs); | |||||
rewriter.replaceOp(op, binary_op.getResult()); | |||||
return success(); | |||||
} | |||||
private: | |||||
gpu::LaunchOp* m_launch_op; | |||||
}; | |||||
using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>; | |||||
struct ReturnOpLowering : public ConversionPattern { | |||||
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) | |||||
: ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx), | |||||
m_launch_op{launch_op} {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value>, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | |||||
auto loc = op->getLoc(); | |||||
//! remove the first gpu.terminator | |||||
m_launch_op->body().front().front().erase(); | |||||
//! if (tid >= nr_tid) {return;} in the begin of the block | |||||
rewriter.setInsertionPointToStart(&(m_launch_op->body().front())); | |||||
Block* cond_block = rewriter.getInsertionBlock(); | |||||
Block::iterator op_position = rewriter.getInsertionPoint(); | |||||
Block* remaining_ops_block = | |||||
rewriter.splitBlock(cond_block, op_position); | |||||
rewriter.setInsertionPointToEnd(cond_block); | |||||
auto index = get_tid(rewriter, loc); | |||||
auto comparison = rewriter.create<mlir::CmpIOp>( | |||||
loc, CmpIPredicate::sge, index, | |||||
m_launch_op->getParentOfType<mlir::FuncOp>() | |||||
.getArguments() | |||||
.back()); | |||||
Block* then_block = | |||||
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint()); | |||||
rewriter.setInsertionPointToEnd(then_block); | |||||
rewriter.create<gpu::TerminatorOp>(loc); | |||||
rewriter.setInsertionPointToEnd(cond_block); | |||||
rewriter.create<mlir::CondBranchOp>( | |||||
loc, comparison, then_block, ArrayRef<Value>(), | |||||
remaining_ops_block, ArrayRef<Value>()); | |||||
rewriter.setInsertionPointToEnd(remaining_ops_block); | |||||
rewriter.create<gpu::TerminatorOp>(loc); | |||||
return success(); | |||||
} | |||||
private: | |||||
gpu::LaunchOp* m_launch_op; | |||||
}; | |||||
struct AssignOpLowering : public ConversionPattern { | |||||
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) | |||||
: ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx), | |||||
m_launch_op{launch_op} {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
auto loc = op->getLoc(); | |||||
AssignOpAdaptor assign_adaptor(operands); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front())); | |||||
auto index = get_tid(rewriter, loc); | |||||
auto loaded_lhs = | |||||
get_operand(rewriter, loc, assign_adaptor.lhs(), index); | |||||
rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index); | |||||
rewriter.eraseOp(op); | |||||
return success(); | |||||
} | |||||
private: | |||||
gpu::LaunchOp* m_launch_op; | |||||
}; | |||||
class MgbToGpuLoweringPass | |||||
: public PassWrapper<MgbToGpuLoweringPass, FunctionPass> { | |||||
public: | |||||
void runOnFunction() override final { | |||||
auto func_op = getFunction(); | |||||
Location loc = func_op.getLoc(); | |||||
OpBuilder builder(&func_op.getBody()); | |||||
Value constantOne = builder.create<ConstantIndexOp>(loc, 1); | |||||
gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>( | |||||
loc, constantOne, constantOne, constantOne, constantOne, | |||||
constantOne, constantOne); | |||||
builder.setInsertionPointToEnd(&(launch_op.body().front())); | |||||
builder.create<gpu::TerminatorOp>(loc); | |||||
OwningRewritePatternList patterns; | |||||
ConversionTarget target(getContext()); | |||||
target.addLegalDialect<StandardOpsDialect>(); | |||||
target.addLegalDialect<gpu::GPUDialect>(); | |||||
target.addIllegalDialect<MgbDialect>(); | |||||
patterns.insert<AddOpLowering, AssignOpLowering, ReturnOpLowering>( | |||||
&getContext(), &launch_op); | |||||
if (failed(applyPartialConversion(func_op, target, patterns))) { | |||||
signalPassFailure(); | |||||
} | |||||
} | |||||
}; | |||||
} // namespace | |||||
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() { | |||||
return std::make_unique<MgbToGpuLoweringPass>(); | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/jit/mlir/ir/passes.h" | |||||
#include <mlir/Conversion/AffineToStandard/AffineToStandard.h> | |||||
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h> | |||||
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h> | |||||
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
class AffineToLLVMLoweringPass : public PassWrapper<AffineToLLVMLoweringPass, | |||||
OperationPass<ModuleOp>> { | |||||
void runOnOperation() final { | |||||
LLVMConversionTarget target(getContext()); | |||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); | |||||
LLVMTypeConverter typeConverter(&getContext()); | |||||
OwningRewritePatternList patterns; | |||||
populateAffineToStdConversionPatterns(patterns, &getContext()); | |||||
populateLoopToStdConversionPatterns(patterns, &getContext()); | |||||
populateStdToLLVMConversionPatterns(typeConverter, patterns); | |||||
auto module = getOperation(); | |||||
if (failed(applyFullConversion(module, target, patterns))) | |||||
signalPassFailure(); | |||||
} | |||||
}; | |||||
} // namespace | |||||
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_llvm_pass() { | |||||
return std::make_unique<AffineToLLVMLoweringPass>(); | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,72 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/ops.td | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#ifndef MGB_MLIR_OPS | |||||
#define MGB_MLIR_OPS | |||||
include "mlir/IR/OpBase.td" | |||||
include "mlir/Interfaces/SideEffectInterfaces.td" | |||||
include "./shape_inference_interface.td" | |||||
def Mgb_Dialect : Dialect { | |||||
let name = "mgb"; | |||||
let cppNamespace = "mgb::jit"; | |||||
} | |||||
class ElemwiseOp<string mnemonic, list<OpTrait> traits = []> : | |||||
Op<Mgb_Dialect, mnemonic, traits>; | |||||
class GenericOp<string mnemonic, list<OpTrait> traits = []> : | |||||
Op<Mgb_Dialect, mnemonic, traits>; | |||||
def AddOp : ElemwiseOp<"add", | |||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | |||||
let summary = "element-wise addition operation"; | |||||
let description = [{ | |||||
The "add" operation performs element-wise addition between two tensors. | |||||
The shapes of the tensor operands are expected to match. | |||||
}]; | |||||
let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs); | |||||
let results = (outs F32MemRef); | |||||
// Specify a parser and printer method. | |||||
let parser = [{ return ::parseBinaryOp(parser, result); }]; | |||||
let printer = [{ return ::printBinaryOp(p, *this); }]; | |||||
// Allow building an AddOp with from the two input operands. | |||||
let builders = [ | |||||
OpBuilder<"OpBuilder &b, OperationState &state, Value lhs, Value rhs"> | |||||
]; | |||||
} | |||||
def ReturnOp : GenericOp<"return", | |||||
[NoSideEffect, HasParent<"FuncOp">, Terminator]> { | |||||
let summary = "return operation"; | |||||
let description = [{ | |||||
The "return" operation represents a return operation within a function. | |||||
The operation takes an no tensor operand and produces no results. | |||||
}]; | |||||
} | |||||
def AssignOp : GenericOp<"assign", []> { | |||||
let summary = "assign op"; | |||||
let description = [{ | |||||
assign rhs to lhs without results | |||||
}]; | |||||
let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs); | |||||
} | |||||
#endif |
@@ -0,0 +1,30 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/shape_inference_interface.td | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#ifndef MGB_JIT_SHAPE_INFERENCE_INTERFACE | |||||
#define MGB_JIT_SHAPE_INFERENCE_INTERFACE | |||||
include "mlir/IR/OpBase.td" | |||||
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { | |||||
let description = [{ | |||||
Interface to access a registered method to infer the return types for an | |||||
operation that can be used during type inference. | |||||
}]; | |||||
let methods = [ | |||||
InterfaceMethod<"Infer and set the output shape for the current operation.", | |||||
"void", "infer_shapes"> | |||||
]; | |||||
} | |||||
#endif // MGB_SHAPE_INFERENCE_INTERFACE |
@@ -0,0 +1,100 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/shape_inference_pass.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/jit/mlir/ir/passes.h" | |||||
#include "megbrain/jit/mlir/ir/shape_inference_interface.h" | |||||
#include <llvm/ADT/SmallPtrSet.h> | |||||
#include <mlir/IR/StandardTypes.h> | |||||
#include <mlir/Pass/Pass.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
#include "megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc" | |||||
namespace { | |||||
class ShapeInferencePass | |||||
: public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { | |||||
public: | |||||
void runOnFunction() override { | |||||
auto f = getFunction(); | |||||
llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist; | |||||
f.walk([&](mlir::Operation* op) { | |||||
if (returns_dynamic_shape(op)) | |||||
op_worklist.insert(op); | |||||
}); | |||||
// Iterate on the operations in the worklist until all operations have | |||||
// been inferred or no change happened (fix point). | |||||
while (!op_worklist.empty()) { | |||||
// Find the next operation ready for inference, that is an operation | |||||
// with all operands already resolved (non-generic). | |||||
auto nextop = llvm::find_if(op_worklist, all_operands_inferred); | |||||
if (nextop == op_worklist.end()) | |||||
break; | |||||
Operation* op = *nextop; | |||||
op_worklist.erase(op); | |||||
if (auto shapeOp = dyn_cast<ShapeInference>(op)) { | |||||
shapeOp.infer_shapes(); | |||||
} else { | |||||
mgb_log_error( | |||||
"unable to infer shape of operation without shape " | |||||
"inference interface"); | |||||
return signalPassFailure(); | |||||
} | |||||
} | |||||
// If the operation worklist isn't empty, this indicates a failure. | |||||
if (!op_worklist.empty()) { | |||||
mgb_log_error( | |||||
"Shape inference failed, %zu operations couldn't be " | |||||
"inferred", | |||||
op_worklist.size()); | |||||
signalPassFailure(); | |||||
} | |||||
} | |||||
//! A utility method that returns if the given operation has all of its | |||||
//! operands inferred. | |||||
static bool all_operands_inferred(Operation* op) { | |||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) { | |||||
return operandType.isa<mlir::MemRefType>(); | |||||
}); | |||||
} | |||||
//! A utility method that returns if the given operation has a dynamically | |||||
//! shaped result. | |||||
static bool returns_dynamic_shape(Operation* op) { | |||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) { | |||||
return !resultType.isa<mlir::MemRefType>(); | |||||
}); | |||||
} | |||||
}; | |||||
} // namespace | |||||
std::unique_ptr<mlir::Pass> mgb::jit::create_shape_inference_pass() { | |||||
return std::make_unique<ShapeInferencePass>(); | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,207 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/mlir_gen.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "./mlir_gen.h" | |||||
#include "./utils.h" | |||||
#include "megbrain/jit/mlir/ir/dialect.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megdnn/dtype.h" | |||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | |||||
#include <mlir/Dialect/StandardOps/IR/Ops.h> | |||||
#include <mlir/IR/Attributes.h> | |||||
#include <mlir/IR/Builders.h> | |||||
#include <mlir/IR/Function.h> | |||||
#include <mlir/IR/MLIRContext.h> | |||||
#include <mlir/IR/Module.h> | |||||
#include <mlir/IR/StandardTypes.h> | |||||
#include <mlir/IR/Types.h> | |||||
#include <mlir/IR/Value.h> | |||||
#include <mlir/IR/Verifier.h> | |||||
#include <mlir/Support/LogicalResult.h> | |||||
#include <llvm/ADT/ScopedHashTable.h> | |||||
#include <llvm/Support/raw_ostream.h> | |||||
using namespace mgb; | |||||
using namespace jit; | |||||
namespace { | |||||
class MLIRGenImpl { | |||||
public: | |||||
MLIRGenImpl(mlir::MLIRContext& context) : m_builder(&context) {} | |||||
std::pair<llvm::StringRef, mlir::OwningModuleRef> gen( | |||||
const InternalGraph& internal_graph, | |||||
const JITExecutor::Args& args) { | |||||
mlir::ModuleOp module = | |||||
mlir::ModuleOp::create(m_builder.getUnknownLoc()); | |||||
//! Create main routine function | |||||
auto func_op = gen_func_op(internal_graph, args); | |||||
module.push_back(func_op); | |||||
if (mlir::failed(mlir::verify(module))) { | |||||
module.emitError("module verification error"); | |||||
return {}; | |||||
} | |||||
return {func_op.getName(), module}; | |||||
} | |||||
private: | |||||
mlir::OpBuilder m_builder; | |||||
llvm::ScopedHashTable<mlir::StringRef, mlir::Value> m_symbol_table; | |||||
mlir::FuncOp gen_func_op(const InternalGraph& internal_graph, | |||||
const JITExecutor::Args& args) { | |||||
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope( | |||||
m_symbol_table); | |||||
std::vector<mlir::Type> func_args; | |||||
for (auto&& arg : args.inputs) { | |||||
func_args.push_back(get_type(arg.layout)); | |||||
} | |||||
for (auto&& arg : args.outputs) { | |||||
func_args.push_back(get_type(arg.layout)); | |||||
} | |||||
//! the last arg is nr_elements | |||||
func_args.push_back(m_builder.getIndexType()); | |||||
auto func_type = m_builder.getFunctionType(func_args, llvm::None); | |||||
//! function name maybe renamed in later pass | |||||
mlir::FuncOp func_op = mlir::FuncOp::create(m_builder.getUnknownLoc(), | |||||
"func", func_type); | |||||
if (!func_op) | |||||
return nullptr; | |||||
func_op.setAttr("llvm.emit_c_interface", | |||||
mlir::UnitAttr::get(m_builder.getContext())); | |||||
auto& entry_block = *func_op.addEntryBlock(); | |||||
size_t idx = 0; | |||||
for (auto&& input : args.inputs) { | |||||
if (mlir::failed(declare(internal_graph.placeholders()[input.idx] | |||||
->output(0) | |||||
->name(), | |||||
entry_block.getArgument(idx)))) { | |||||
return nullptr; | |||||
} | |||||
idx++; | |||||
} | |||||
for (auto&& output : args.outputs) { | |||||
if (mlir::failed(declare(output.from->name(), | |||||
entry_block.getArgument(idx)))) { | |||||
return nullptr; | |||||
} | |||||
idx++; | |||||
} | |||||
m_builder.setInsertionPointToStart(&entry_block); | |||||
if (mlir::failed(gen_func_body(internal_graph, args))) { | |||||
func_op.erase(); | |||||
return nullptr; | |||||
} | |||||
jit::ReturnOp return_op; | |||||
if (!return_op) { | |||||
m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc()); | |||||
} | |||||
std::string op_content = to_string(func_op); | |||||
func_op.setName( | |||||
ssprintf("jit_mlir_%" PRIx64, | |||||
XXHash{}.update(op_content.data(), op_content.size()) | |||||
.digest())); | |||||
return func_op; | |||||
} | |||||
mlir::LogicalResult gen_func_body(const InternalGraph& internal_graph, | |||||
const JITExecutor::Args& args) { | |||||
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope( | |||||
m_symbol_table); | |||||
cg::DepOprIter{[&](cg::OperatorNodeBase* opr) { | |||||
if (opr->same_type<JITPlaceholder>()) { | |||||
return; | |||||
} | |||||
if (opr->same_type<opr::Elemwise>()) { | |||||
auto&& out = gen_op(opr->cast_final<opr::Elemwise>()); | |||||
mgb_assert( | |||||
mlir::succeeded(declare(opr->output(0)->name(), out))); | |||||
} | |||||
}}.add(internal_graph.output()); | |||||
m_builder.create<AssignOp>(m_builder.getUnknownLoc(), | |||||
get(internal_graph.output()), | |||||
get(args.outputs[0].from)); | |||||
return mlir::success(); | |||||
} | |||||
mlir::Value gen_op(const opr::Elemwise& opr) { | |||||
switch (opr.param().mode) { | |||||
case opr::Elemwise::Mode::ADD: | |||||
return m_builder.create<AddOp>(m_builder.getUnknownLoc(), | |||||
get(opr.input(0)), | |||||
get(opr.input(1))); | |||||
break; | |||||
default: | |||||
return nullptr; | |||||
} | |||||
return nullptr; | |||||
} | |||||
mlir::Type get_type(const TensorLayout& layout) { | |||||
std::vector<int64_t> shape; | |||||
for (size_t i = 0; i < layout.ndim; i++) { | |||||
shape.push_back(layout[i]); | |||||
} | |||||
mgb_assert(layout.ndim != 0); | |||||
switch (layout.dtype.enumv()) { | |||||
case DTypeEnum::Float32: | |||||
return mlir::MemRefType::get(shape, m_builder.getF32Type()); | |||||
default: | |||||
mgb_throw(InternalError, "No supported dtype: %s", | |||||
layout.dtype.name()); | |||||
} | |||||
return mlir::UnrankedMemRefType::get(m_builder.getNoneType(), 0); | |||||
} | |||||
mlir::Value get(const VarNode* var) { | |||||
if (auto ret = m_symbol_table.lookup(var->name())) { | |||||
return ret; | |||||
} | |||||
mgb_throw(InternalError, "Unknown var: %s", var->cname()); | |||||
} | |||||
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { | |||||
if (m_symbol_table.count(var)) { | |||||
return mlir::failure(); | |||||
} | |||||
m_symbol_table.insert(var, value); | |||||
return mlir::success(); | |||||
} | |||||
}; | |||||
} // namespace | |||||
std::pair<llvm::StringRef, mlir::OwningModuleRef> mgb::jit::mlir_gen( | |||||
mlir::MLIRContext& context, | |||||
const mgb::jit::InternalGraph& internal_graph, | |||||
const mgb::jit::JITExecutor::Args& args) { | |||||
return MLIRGenImpl(context).gen(internal_graph, args); | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/mlir_gen.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/jit/executor_opr.h" | |||||
#include "megbrain/jit/internal_graph.h" | |||||
#include <mlir/IR/Module.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
/** | |||||
* \brief generate mlir from subgraph. | |||||
* | |||||
* \param context mlir context | |||||
* \param internal_graph internal graph used to generate mlir | |||||
* \param args input args for the internal graph | |||||
* \return A pair of {kernel_name, module} | |||||
**/ | |||||
std::pair<llvm::StringRef, mlir::OwningModuleRef> mlir_gen( | |||||
mlir::MLIRContext& context, const InternalGraph& internal_graph, | |||||
const JITExecutor::Args& args); | |||||
} | |||||
} // namespace mgb | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/exception.h" | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/dtype.h" | |||||
#include <string> | |||||
#include <mlir/ExecutionEngine/CRunnerUtils.h> | |||||
#include <llvm/Support/raw_ostream.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
template <typename T> | |||||
std::string to_string(T&& t) { | |||||
std::string ret; | |||||
llvm::raw_string_ostream stream(ret); | |||||
t.print(stream); | |||||
return ret; | |||||
} | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/dialect.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include <mlir/IR/OpDefinition.h> | |||||
#include <mlir/IR/Dialect.h> | |||||
#include <mlir/IR/Function.h> | |||||
#include <mlir/Interfaces/SideEffectInterfaces.h> | |||||
#include "megbrain/jit/mlir/ir/shape_inference_interface.h" | |||||
namespace mgb { | |||||
namespace jit { | |||||
class MgbDialect : public ::mlir::Dialect { | |||||
public: | |||||
explicit MgbDialect(::mlir::MLIRContext* ctx); | |||||
//! We should register this function in dialect | |||||
static llvm::StringRef getDialectNamespace() { return "mgb::jit"; } | |||||
}; | |||||
#define GET_OP_CLASSES | |||||
using namespace mlir; | |||||
#include "megbrain/jit/mlir/ir/ops.h.inc" | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,43 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/passes.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <mlir/IR/Module.h> | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include <memory> | |||||
#include <mlir/Pass/Pass.h> | |||||
namespace mgb { | |||||
namespace jit { | |||||
std::unique_ptr<mlir::Pass> create_shape_inference_pass(); | |||||
/** | |||||
* \brief Create a pass for lowering to operations in the `Affine` and `Std` | |||||
* dialects, for a subset of the megbrain IR. | |||||
*/ | |||||
std::unique_ptr<mlir::Pass> create_lower_to_affine_pass(); | |||||
std::unique_ptr<mlir::Pass> create_lower_to_llvm_pass(); | |||||
std::unique_ptr<mlir::Pass> create_lower_to_gpu_pass(); | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/shape_inference_interface.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "mlir/IR/OpDefinition.h" | |||||
namespace mgb { | |||||
namespace jit { | |||||
/// Include the auto-generated declarations. | |||||
#include "megbrain/jit/mlir/ir/shape_inference_interface.h.inc" | |||||
} // end namespace toy | |||||
} // end namespace mlir | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -18,8 +18,8 @@ | |||||
*/ | */ | ||||
/*! | /*! | ||||
* * \brief fast division for unsigned int | |||||
* */ | |||||
* \brief fast division for unsigned int | |||||
*/ | |||||
struct Uint32Fastdiv { | struct Uint32Fastdiv { | ||||
unsigned int m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift; | unsigned int m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift; | ||||
@@ -14,6 +14,7 @@ | |||||
#include "megbrain/jit/executor_opr.h" | #include "megbrain/jit/executor_opr.h" | ||||
#include "megbrain/opr/basic_arith_wrapper.h" | #include "megbrain/opr/basic_arith_wrapper.h" | ||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
#include "megdnn/dtype.h" | |||||
#if MGB_JIT | #if MGB_JIT | ||||
using namespace mgb; | using namespace mgb; | ||||
@@ -120,6 +121,44 @@ void run<grad>(Backend backend, CompNode cn) { | |||||
template <> | template <> | ||||
void run<void>(Backend, CompNode) {} | void run<void>(Backend, CompNode) {} | ||||
#if MGB_JIT_MLIR | |||||
void run_mlir(CompNode cn) { | |||||
set_backend(Backend::MLIR); | |||||
auto graph = ComputingGraph::make(); | |||||
HostTensorGenerator<dtype::Float32> gen; | |||||
auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn), | |||||
host_x2 = gen({23, 42}, cn); | |||||
auto a = opr::Host2DeviceCopy::make(*graph, host_x0), | |||||
b = opr::Host2DeviceCopy::make(*graph, host_x1), | |||||
c = opr::Host2DeviceCopy::make(*graph, host_x2); | |||||
auto y = a + b + c; | |||||
VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | |||||
auto ig_gen = | |||||
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||||
for (auto i : get_rev_topo_order(y)) { | |||||
if (!i->same_type<opr::Host2DeviceCopy>()) { | |||||
ig_gen->add_opr(i); | |||||
} | |||||
} | |||||
auto igraph = ig_gen->generate(); | |||||
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); | |||||
HostTensorND host_y, host_y_jit; | |||||
auto func = graph->compile({make_callback_copy(y, host_y), | |||||
make_callback_copy(y_jit, host_y_jit)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | |||||
} | |||||
#endif | |||||
} // anonymous namespace | } // anonymous namespace | ||||
#if MGB_JIT_HALIDE | #if MGB_JIT_HALIDE | ||||
@@ -140,6 +179,20 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) { | |||||
run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0")); | run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0")); | ||||
} | } | ||||
#if MGB_JIT_MLIR | |||||
TEST(TestJITMlirCodeGen, Basic) { | |||||
auto cn = CompNode::load("cpu0"); | |||||
run_mlir(cn); | |||||
} | |||||
TEST(TestJITMlirCodeGen, BasicGPU) { | |||||
REQUIRE_GPU(1); | |||||
auto cn = CompNode::load("gpu0"); | |||||
run_mlir(cn); | |||||
} | |||||
#endif | |||||
#endif // MGB_JIT | #endif // MGB_JIT | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -35,6 +35,9 @@ void jit::set_backend(Backend backend) { | |||||
case Backend::NVRTC: | case Backend::NVRTC: | ||||
setenv("MGB_JIT_BACKEND", "NVRTC", 1); | setenv("MGB_JIT_BACKEND", "NVRTC", 1); | ||||
return; | return; | ||||
case Backend::MLIR: | |||||
setenv("MGB_JIT_BACKEND", "MLIR", 1); | |||||
return; | |||||
default: | default: | ||||
mgb_assert(0); | mgb_assert(0); | ||||
} | } | ||||
@@ -15,7 +15,7 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace jit { | namespace jit { | ||||
enum class Backend { NONE, HALIDE, NVRTC }; | |||||
enum class Backend { NONE, HALIDE, NVRTC, MLIR }; | |||||
void set_backend(Backend backend); | void set_backend(Backend backend); | ||||