diff --git a/.gitattributes b/.gitattributes index 6e3614c6..1f0fc3a8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,3 +4,4 @@ dnn/src/cuda/conv_bias/int8/kimpl/* binary dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary dnn/src/cuda/sass/prebuilt/map_defs.cpp binary +tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text diff --git a/src/core/impl/comp_node_env.cpp b/src/core/impl/comp_node_env.cpp index 98c0c61c..c6520288 100644 --- a/src/core/impl/comp_node_env.cpp +++ b/src/core/impl/comp_node_env.cpp @@ -162,6 +162,15 @@ void mgb::_on_cuda_error(const char* expr, cudaError_t err, const char* file, cudaGetErrorString(err), expr, file, func, line); } +void mgb::_on_cuda_cu_error(const char* expr, CUresult err, const char* file, + const char* func, int line) { + const char* msg; + cuGetErrorName(err, &msg); + mgb_throw(CudaError, "cuda error %d: %s (%s at %s:%s:%d)", int(err), msg, + expr, file, func, line); +} + + void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, const ContinuationCtx& cont) { m_comp_node = comp_node; diff --git a/src/core/include/megbrain/comp_node_env.h b/src/core/include/megbrain/comp_node_env.h index ee5817f3..4b7a2f3e 100644 --- a/src/core/include/megbrain/comp_node_env.h +++ b/src/core/include/megbrain/comp_node_env.h @@ -22,6 +22,7 @@ #if MGB_CUDA #include +#include #if MGB_ENABLE_LOGGING #define MGB_CUDA_CHECK(expr) \ @@ -32,6 +33,16 @@ __func__, __LINE__); \ } \ } while (0) + +#define MGB_CUDA_CU_CHECK(expr) \ + do { \ + CUresult __cuda_check_code = (expr); \ + if (!mgb_likely(__cuda_check_code == CUDA_SUCCESS)) { \ + ::mgb::_on_cuda_cu_error(#expr, __cuda_check_code, __FILE__, \ + __func__, __LINE__); \ + } \ + } while (0) + #else #define MGB_CUDA_CHECK(expr) \ do { \ @@ -41,6 +52,14 @@ } \ } while (0) +#define MGB_CUDA_CU_CHECK(expr) \ + do { \ + CUresult __cuda_check_code = (expr); \ + if (!mgb_likely(__cuda_check_code == CUDA_SUCCESS)) { \ + ::mgb::_on_cuda_cu_error(#expr, __cuda_check_code, "", "", 1); \ + } \ + } while (0) + #endif //MGB_ENABLE_LOGGING #endif //MGB_CUDA @@ -178,6 +197,9 @@ namespace mgb { #if MGB_CUDA [[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err, const char* file, const char* func, int line); +[[noreturn]] void _on_cuda_cu_error(const char* expr, CUresult err, + const char* file, const char* func, + int line); #endif diff --git a/src/jit/impl/compiler.cpp b/src/jit/impl/compiler.cpp index 032293c9..36a82055 100644 --- a/src/jit/impl/compiler.cpp +++ b/src/jit/impl/compiler.cpp @@ -11,6 +11,7 @@ #include "./halide/compiler_cuda.h" #include "./nvrtc/compiler_cuda.h" +#include "./mlir/compiler.h" #include "megbrain/jit/compiler.h" #include "megbrain/utils/hash.h" @@ -54,6 +55,8 @@ bool Compiler::is_supported_device(CompNode::DeviceType device) { case CompNode::DeviceType::CUDA: return true; #endif + case CompNode::DeviceType::CPU: + return true; default: return false; } @@ -87,12 +90,30 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) { break; } #endif +#if MGB_JIT_MLIR + if (!backend || !strcmp(backend, "MLIR")) { + compiler = std::make_unique( + CompNode::DeviceType::CUDA); + break; + } +#endif if (!backend || !strcmp(backend, "NVRTC")) { compiler = std::make_unique(); break; } #endif - // fall through + mgb_throw(InternalError, "No compiler support for cuda"); + break; + case CompNode::DeviceType::CPU: +#if MGB_JIT_MLIR + if (!backend || !strcmp(backend, "MLIR")) { + compiler = std::make_unique( + CompNode::DeviceType::CPU); + 
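// (Editorial note, not a patch line: unlike the NVRTC/Halide backends, the
// MLIR backend serves both device types; the CompNode::DeviceType passed to
// MLIRCompiler selects between the CPU and CUDA lowering pipelines defined
// in src/jit/impl/mlir/compiler.cpp.)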
break; + } +#endif + mgb_throw(InternalError, "No compiler support for cpu"); + break; default: mgb_throw(InternalError, "unsupported JIT config: " diff --git a/src/jit/impl/mlir/compiler.cpp b/src/jit/impl/mlir/compiler.cpp new file mode 100644 index 00000000..4742c685 --- /dev/null +++ b/src/jit/impl/mlir/compiler.cpp @@ -0,0 +1,198 @@ +/** + * \file src/jit/impl/mlir/compiler.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "./compiler.h" +#include "./executable_cpu.h" +#include "./executable_cuda.h" +#include "./mlir_gen.h" + +#include "megbrain/common.h" +#include "megbrain/comp_node_env.h" +#include "megbrain/jit/mlir/ir/dialect.h" +#include "megbrain/jit/mlir/ir/passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace mgb; +using namespace jit; + +namespace { + +struct LLVMInitializer { + LLVMInitializer() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + } +}; +static LLVMInitializer initializer; + +#if MGB_CUDA +mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location, + llvm::StringRef) { + OwnedBlob result = std::make_unique>( + ptx.data(), ptx.data() + ptx.size()); + + return result; +} + +#endif + +void add_cpu_lowering_pass(mlir::PassManager& manager) { + { + mlir::OpPassManager& opt_pm = manager.nest(); + opt_pm.addPass(create_shape_inference_pass()); + opt_pm.addPass(mlir::createCanonicalizerPass()); + opt_pm.addPass(mlir::createCSEPass()); + } + + manager.addPass(create_lower_to_affine_pass()); + { + mlir::OpPassManager& opt_pm = manager.nest(); + opt_pm.addPass(mlir::createCanonicalizerPass()); + opt_pm.addPass(mlir::createCSEPass()); + opt_pm.addPass(mlir::createLoopFusionPass()); + opt_pm.addPass(mlir::createMemRefDataFlowOptPass()); + } + manager.addPass(create_lower_to_llvm_pass()); +} + +#if MGB_CUDA +void add_cuda_lowering_pass(mlir::PassManager& manager, CompNode cn) { + { + mlir::OpPassManager& opt_pm = manager.nest(); + opt_pm.addPass(create_shape_inference_pass()); + opt_pm.addPass(mlir::createCanonicalizerPass()); + opt_pm.addPass(mlir::createCSEPass()); + } + manager.addPass(create_lower_to_gpu_pass()); + { + mlir::OpPassManager& opt_pm = manager.nest(); + opt_pm.addPass(mlir::createCanonicalizerPass()); + opt_pm.addPass(mlir::createCSEPass()); + opt_pm.addPass(mlir::createLoopFusionPass()); + opt_pm.addPass(mlir::createMemRefDataFlowOptPass()); + } + manager.addPass(mlir::createGpuKernelOutliningPass()); + { + auto& kernel_pm = manager.nest(); + kernel_pm.addPass(mlir::createLowerGpuOpsToNVVMOpsPass()); + + auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; + kernel_pm.addPass(mlir::createConvertGPUKernelToBlobPass( + mlir::translateModuleToNVVMIR, compile_ptx_to_cubin, + "nvptx64-nvidia-cuda", + ssprintf("sm_%d%d", prop.major, prop.minor), "+ptx60", + MLIRCUDAExecutable::sm_blob_annotation)); + } +} +#endif + +} // namespace + +/* ==================== MLIRCompiler ===================== */ + +thread_local mlir::MLIRContext MLIRCompiler::sm_ctx; + 
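// A minimal sketch (editorial, not part of the patch) of how the pipelines
// above are driven; every pass named here is defined or declared elsewhere in
// this patch, and `module` is assumed to come from mlir_gen():
//
//     mlir::PassManager manager(module->getContext());
//     mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
//     opt_pm.addPass(create_shape_inference_pass());   // resolve result types
//     opt_pm.addPass(mlir::createCanonicalizerPass()); // fold and clean up
//     opt_pm.addPass(mlir::createCSEPass());           // drop duplicate work
//     manager.addPass(create_lower_to_affine_pass());  // mgb ops -> affine loops
//     manager.addPass(create_lower_to_llvm_pass());    // affine/std -> LLVM dialect
//     mgb_assert(mlir::succeeded(manager.run(*module)));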
+MLIRCompiler::MLIRCompiler(CompNode::DeviceType device_type) + : m_device_type{device_type} { + mlir::registerAllDialects(); + mlir::registerDialect(); + +#if MGB_CUDA + if (m_device_type == CompNode::DeviceType::CUDA) { + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + } +#endif +} + +void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module, + CompNode cn) { + mgb_assert(cn.device_type() == m_device_type); + mlir::PassManager manager(module->getContext()); + switch (m_device_type) { + case CompNode::DeviceType::CPU: + add_cpu_lowering_pass(manager); + break; +#if MGB_CUDA + case CompNode::DeviceType::CUDA: + add_cuda_lowering_pass(manager, cn); + break; +#endif + default: + mgb_throw(InternalError, "Unsupport device type: %d", + static_cast(m_device_type)); + break; + } + mgb_assert(mlir::succeeded(manager.run(*module))); +} + +std::unique_ptr MLIRCompiler::do_compile( + const InternalGraph& graph, const JITExecutor::Args& args) { + MGB_MARK_USED_VAR(graph); + MGB_MARK_USED_VAR(args); + + mlir::MLIRContext ctx; + ctx.printStackTraceOnDiagnostic(true); + ctx.printOpOnDiagnostic(true); + + auto&& res = mlir_gen(ctx, graph, args); + mgb_assert(res.second, "failed to generate module"); + + CompNode cn = args.owner->comp_node(); + run_lowering_pass(res.second, cn); + switch (cn.device_type()) { + case CompNode::DeviceType::CPU: + return std::make_unique(res.second, + res.first.str()); +#if MGB_CUDA + case CompNode::DeviceType::CUDA: + return std::make_unique(res.second, + res.first.str()); +#endif + default: + mgb_throw(InternalError, "Unsupport device type: %d", + static_cast(cn.device_type())); + return nullptr; + } +} + +size_t MLIRCompiler::get_nr_workspace_outputs(JITExecutor* opr) const { + MGB_MARK_USED_VAR(opr); + return 0; +} + +void MLIRCompiler::init_workspace_size_infer(JITExecutor* opr) { + MGB_MARK_USED_VAR(opr); +} + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/compiler.h b/src/jit/impl/mlir/compiler.h new file mode 100644 index 00000000..594a04ec --- /dev/null +++ b/src/jit/impl/mlir/compiler.h @@ -0,0 +1,56 @@ +/** + * \file src/jit/impl/mlir/compiler.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/jit/compiler.h" + +#include +#include + +namespace mgb { +namespace jit { + +/*! 
+ * \brief MLIR compiler + */ +class MLIRCompiler final : public Compiler { + std::unique_ptr do_compile( + const InternalGraph& graph, const JITExecutor::Args& args) override; + +public: + MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU); + Property property() const override { + using F = Property::Flag; + return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM, + JITFeatureBits::NONE, 64}; + } + + size_t get_nr_workspace_outputs(JITExecutor* opr) const override; + + void init_workspace_size_infer(JITExecutor* opr) override; + +private: + void run_lowering_pass(mlir::OwningModuleRef& module, CompNode cn); + + CompNode::DeviceType m_device_type; + static thread_local mlir::MLIRContext sm_ctx; +}; + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/executable_cpu.cpp b/src/jit/impl/mlir/executable_cpu.cpp new file mode 100644 index 00000000..d9190736 --- /dev/null +++ b/src/jit/impl/mlir/executable_cpu.cpp @@ -0,0 +1,134 @@ +/** + * \file src/jit/impl/mlir/executable_cpu.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "./executable_cpu.h" +#include "./utils.h" + +#include + +using namespace mgb; +using namespace jit; + +namespace { + +template +void* tensor2memref_dim(const megdnn::TensorND& tensor) { + switch (tensor.layout.dtype.enumv()) { + case megdnn::DTypeEnum::Float32: { + StridedMemRefType* desc = + static_cast*>( + malloc(sizeof(StridedMemRefType))); + desc->basePtr = tensor.ptr(); + desc->data = tensor.ptr(); + desc->offset = 0; + for (size_t i = 0; i < tensor.layout.ndim; i++) { + desc->sizes[i] = tensor.layout.shape[i]; + desc->strides[i] = tensor.layout.stride[i]; + } + return desc; + break; + } + default: + mgb_throw(InternalError, "Unsupport dtype, got %s", + tensor.layout.dtype.name()); + break; + } + return nullptr; +} + +void* tensor2memref(const megdnn::TensorND& tensor) { + switch (tensor.layout.ndim) { +#define cb(i) \ + case i: \ + return tensor2memref_dim(tensor) + + cb(1); + cb(2); + cb(3); + cb(4); + cb(5); + default: + mgb_throw(InternalError, "Unsupported ndim, got %zu", + tensor.layout.ndim); +#undef cb + } +} + +} // namespace +MLIRCPUExecutable::MLIRCPUExecutable(mlir::OwningModuleRef& module, + const std::string& kernel_name) + : m_kernel_name{kernel_name} { + auto opt_pipeline = mlir::makeOptimizingTransformer(3, 3, 0); + std::vector libs; + auto&& engine = mlir::ExecutionEngine::create( + *module, opt_pipeline, llvm::None, + std::vector(libs.begin(), libs.end()), true, + false); + mgb_assert(engine); + m_engine = std::move(*engine); +} + +void MLIRCPUExecutable::execute(JITExecutor* fusion_opr) { + auto&& args = fusion_opr->args(); + std::vector args_array(args.inputs.size() + args.outputs.size()); + std::vector args_array_pointer(args.inputs.size() + + args.outputs.size()); + size_t idx = 0; + for (size_t i = 0; i < args.inputs.size(); i++) { + args_array[idx] = + tensor2memref({args.inputs[i].from->dev_tensor().raw_ptr(), + args.inputs[i].layout}); + args_array_pointer[idx] = &args_array[idx]; + idx++; + 
} + int64_t nr_elements = 0; + for (size_t i = 0; i < args.outputs.size(); i++) { + if (nr_elements == 0) { + nr_elements = args.outputs[i].layout.total_nr_elems(); + } else { + mgb_assert(static_cast(nr_elements) == + args.outputs[i].layout.total_nr_elems(), + "The number of elements of outputs mismatch, expected: " + "%zu got: %zu(%s)", + static_cast(nr_elements), + args.outputs[i].layout.total_nr_elems(), + args.outputs[i].layout.to_string().c_str()); + } + args_array[idx] = + tensor2memref({args.outputs[i].from->dev_tensor().raw_ptr(), + args.outputs[i].layout}); + args_array_pointer[idx] = &args_array[idx]; + idx++; + } + + args_array_pointer[idx++] = &nr_elements; + std::string adapter_name = std::string("_mlir_ciface_") + m_kernel_name; + auto err = m_engine->invoke( + adapter_name, llvm::MutableArrayRef(args_array_pointer)); + if (err) { + mgb_throw(InternalError, "failed to run MLIR kernel %s\n", + m_kernel_name.c_str()); + } + + for (size_t i = 0; i < args_array.size(); i++) { + free(args_array[i]); + } +} + +MLIRCPUExecutable::~MLIRCPUExecutable() {} + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/executable_cpu.h b/src/jit/impl/mlir/executable_cpu.h new file mode 100644 index 00000000..730d8361 --- /dev/null +++ b/src/jit/impl/mlir/executable_cpu.h @@ -0,0 +1,51 @@ +/** + * \file src/jit/impl/mlir/executable_cpu.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/jit/compiler.h" + +#include +#include + +namespace mgb { +namespace jit { + +/*! + * \brief Executable class for MLIR + */ +class MLIRCPUExecutable final : public Executable { +public: + MLIRCPUExecutable(mlir::OwningModuleRef& module, + const std::string& kernel_name); + ~MLIRCPUExecutable(); + + /*! + * \brief execute + * A executable instance can be executed by one or more fusion_opr + */ + void execute(JITExecutor* fusion_opr) override final; + +private: + std::unique_ptr m_engine; + std::string m_kernel_name; +}; + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/executable_cuda.cpp b/src/jit/impl/mlir/executable_cuda.cpp new file mode 100644 index 00000000..2dea7499 --- /dev/null +++ b/src/jit/impl/mlir/executable_cuda.cpp @@ -0,0 +1,166 @@ +/** + * \file src/jit/impl/mlir/executable_cuda.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include +#include "megbrain_build_config.h" +#include "megdnn/dtype.h" +#if MGB_JIT && MGB_JIT_MLIR + +#if MGB_CUDA +#include "./executable_cuda.h" +#include "./utils.h" +#include "megbrain/utils/timer.h" +#include "megbrain/utils/persistent_cache.h" +#include "megbrain/comp_node_env.h" + +#include +#include +#include + +using namespace mgb; +using namespace jit; + +namespace { +template +void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, + int block_size) { + auto&& args = fusion_opr->args(); + std::vector> param_holders; + std::vector params; + + auto set_params = [¶m_holders, ¶ms]( + void* ptr, const megdnn::TensorLayout& layout) { + param_holders.push_back(StridedMemRefType{}); + StridedMemRefType& desc = param_holders.back(); + desc.basePtr = static_cast(ptr); + params.push_back(&(desc.basePtr)); + desc.data = static_cast(ptr); + params.push_back(&(desc.data)); + desc.offset = 0; + params.push_back(&(desc.offset)); + for (size_t i = 0; i < layout.ndim; i++) { + desc.sizes[i] = layout.shape[i]; + params.push_back(&(desc.sizes[i])); + desc.strides[i] = layout.stride[i]; + params.push_back(&(desc.strides[i])); + } + }; + for (const auto& arg : args.inputs) { + set_params(arg.from->dev_tensor().raw_ptr(), arg.layout); + } + int64_t nr_elements = 0; + for (const auto& arg : args.outputs) { + if (nr_elements == 0) { + nr_elements = arg.layout.total_nr_elems(); + } else { + mgb_assert(static_cast(nr_elements) == + arg.layout.total_nr_elems(), + "The number of elements of outputs mismatch, expected: " + "%zu got: %zu(%s)", + static_cast(nr_elements), + arg.layout.total_nr_elems(), + arg.layout.to_string().c_str()); + } + + set_params(arg.from->dev_tensor().raw_ptr(), arg.layout); + } + const CompNodeEnv& env = + CompNodeEnv::from_comp_node(fusion_opr->comp_node()); + + int64_t num_block = (nr_elements - 1) / block_size + 1; + params.insert(params.begin(), &nr_elements); + MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0, + env.cuda_env().stream, params.data(), 0)); +} +} // namespace + +const std::string MLIRCUDAExecutable::sm_blob_annotation = "nvvm.cubin"; +MLIRCUDAExecutable::MLIRCUDAExecutable(mlir::OwningModuleRef& module, + const std::string& kernel_name) { + m_kernel_name = kernel_name + "_kernel"; + auto kernel_module = + module->lookupSymbol(m_kernel_name); + mgb_assert(kernel_module, "Expected gpu kernel module"); + + auto binary_attr = kernel_module.getAttrOfType( + llvm::StringRef(sm_blob_annotation)); + mgb_assert(binary_attr, "Missing %s attribute in gpu kernel module", + sm_blob_annotation.c_str()); + m_kernel_data = binary_attr.getValue().str(); +} + +void MLIRCUDAExecutable::execute(JITExecutor* fusion_opr) { + FuncCache* func; + auto cn = fusion_opr->comp_node(); + auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; + func = &m_func_cache[{prop.major, prop.minor}]; + func->kernel_data = m_kernel_data; + func->exec(fusion_opr, this); +} + +MLIRCUDAExecutable::~MLIRCUDAExecutable() {} + +void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr, + const MLIRCUDAExecutable* cuda_exe) { + Func* func; + { + MGB_LOCK_GUARD(mtx); + auto ins = cn2func.insert({fusion_opr->comp_node(), {}}); + func = &ins.first->second; + if (ins.second) { + MGB_CUDA_CU_CHECK( + cuModuleLoadData(&func->module, kernel_data.data())); + MGB_CUDA_CU_CHECK( + cuModuleGetFunction(&func->func, func->module, + cuda_exe->m_kernel_name.c_str())); + int min_grid_size = 0; + MGB_CUDA_CU_CHECK(cuOccupancyMaxPotentialBlockSize( + 
&min_grid_size, &func->block_size, func->func, nullptr, 0, + 0)); + } + } + + mgb_assert(fusion_opr->args().outputs.size() == 1, + "Currently only support 1 outputs, got %zu", + fusion_opr->args().outputs.size()); + int out_dim = fusion_opr->args().outputs[0].layout.ndim; + DType dtype = fusion_opr->args().outputs[0].layout.dtype; +#define cb_outdim(_ndim, _dtype) \ + if (_ndim == out_dim) { \ + setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \ + func->block_size); \ + return; \ + } + +#define cb(_dtype) \ + cb_outdim(1, float); \ + cb_outdim(2, float); \ + cb_outdim(3, float); \ + cb_outdim(4, float); \ + mgb_throw(InternalError, "unsupported out_dim=%zu", \ + static_cast(out_dim)); \ + return; + + switch (dtype.enumv()) { + case DTypeEnum::Float32: + cb(float); + default: + mgb_throw(InternalError, "unsupport dtype: %s", dtype.name()); + } +#undef cb +#undef cb_outdim +} + +#endif // MGB_CUDA +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/executable_cuda.h b/src/jit/impl/mlir/executable_cuda.h new file mode 100644 index 00000000..11a12626 --- /dev/null +++ b/src/jit/impl/mlir/executable_cuda.h @@ -0,0 +1,74 @@ +/** + * \file src/jit/impl/mlir/executable_cuda.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#if MGB_CUDA +#include "megbrain/jit/compiler.h" + +#include + +#include + +namespace mgb { +namespace jit { + +/*! + * \brief Executable class for MLIR + */ +class MLIRCUDAExecutable final : public Executable { +public: + MLIRCUDAExecutable(mlir::OwningModuleRef& module, + const std::string& kernel_name); + ~MLIRCUDAExecutable(); + + /*! + * \brief execute + * A executable instance can be executed by one or more fusion_opr + */ + void execute(JITExecutor* fusion_opr) override final; + + const static std::string sm_blob_annotation; +private: + //! cache for a func on a specific device + struct FuncCache { + struct Func { + int block_size{-1}; + CUmodule module{nullptr}; + CUfunction func{nullptr}; + }; + + std::mutex mtx; + std::string kernel_data; + CompNode::UnorderedMap cn2func; + + void exec(const JITExecutor* fusion_opr, + const MLIRCUDAExecutable* cuda_exe); + }; + + std::string m_kernel_name; + std::string m_kernel_data; + + //! (cuda_major, cuda_minor) => func + ThinHashMap, FuncCache> m_func_cache; +}; + +} // namespace jit +} // namespace mgb + +#endif // MGB_CUDA +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/ir/common.cpp b/src/jit/impl/mlir/ir/common.cpp new file mode 100644 index 00000000..bddaeb6b --- /dev/null +++ b/src/jit/impl/mlir/ir/common.cpp @@ -0,0 +1,41 @@ +/** + * \file src/jit/impl/mlir/ir/common.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "common.h" + +#include + +using namespace mgb; +using namespace jit; + +mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, + mlir::Location loc, + mlir::PatternRewriter& rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto* parent_block = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parent_block->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parent_block->back()); + return alloc; +} + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/ir/common.h b/src/jit/impl/mlir/ir/common.h new file mode 100644 index 00000000..251c569e --- /dev/null +++ b/src/jit/impl/mlir/ir/common.h @@ -0,0 +1,32 @@ +/** + * \file src/jit/impl/mlir/ir/common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include +#include +#include + +namespace mgb { +namespace jit { + +mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, + mlir::PatternRewriter& rewriter); + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/ir/dialect.cpp b/src/jit/impl/mlir/ir/dialect.cpp new file mode 100644 index 00000000..dcff5900 --- /dev/null +++ b/src/jit/impl/mlir/ir/dialect.cpp @@ -0,0 +1,91 @@ +/** + * \file src/jit/impl/mlir/ir/dialect.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/jit/mlir/ir/dialect.h" + +#include +#include +#include +#include + +using namespace mgb; +using namespace jit; + +MgbDialect::MgbDialect(mlir::MLIRContext *ctx) : mlir::Dialect("mgb", ctx) { + addOperations< +#define GET_OP_LIST +#include "megbrain/jit/mlir/ir/ops.cpp.inc" + >(); +} + +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + llvm::SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. 
+ if (FunctionType funcType = type.dyn_cast()) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << op->getName() << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +///////////////////////// ElemwiseOp ///////////////////////////////////////////// + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(lhs.getType()); + state.addOperands({lhs, rhs}); +} +void AddOp::infer_shapes() { getResult().setType(getOperand(0).getType()); } + +#define GET_OP_CLASSES +#include "megbrain/jit/mlir/ir/ops.cpp.inc" + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp new file mode 100644 index 00000000..dbfbdeb3 --- /dev/null +++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp @@ -0,0 +1,159 @@ +/** + * \file src/jit/impl/mlir/ir/lower_to_affine_pass.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/common.h" +#include "megbrain/jit/mlir/ir/dialect.h" +#include "megbrain/jit/mlir/ir/passes.h" + +#include "./common.h" + +#include +#include +#include +#include + +#include + +using namespace mgb; +using namespace jit; + +namespace { + +using LoopIterationFn = function_ref; + +void lower_op_to_loops(Operation* op, ValueRange operands, + PatternRewriter& rewriter, + LoopIterationFn process_iteration) { + auto memref_type = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter); + + SmallVector lower_bounds(memref_type.getRank(), 0); + SmallVector steps(memref_type.getRank(), 1); + buildAffineLoopNest( + rewriter, loc, lower_bounds, memref_type.getShape(), steps, + [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) { + Value value_to_store = + process_iteration(nested_builder, operands, ivs); + nested_builder.create(loc, value_to_store, alloc, + ivs); + }); + + // Replace this operation with the generated alloc. 
+ rewriter.replaceOp(op, alloc); +} + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext* ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); + lower_op_to_loops( + op, operands, rewriter, + [loc](OpBuilder& builder, ValueRange memref_operands, + ValueRange loop_ivs) { + typename BinaryOp::Adaptor binary_adaptor(memref_operands); + + auto loaded_lhs = builder.create( + loc, binary_adaptor.lhs(), loop_ivs); + auto loaded_rhs = builder.create( + loc, binary_adaptor.rhs(), loop_ivs); + + return builder.create(loc, loaded_lhs, + loaded_rhs); + }); + return success(); + } +}; +using AddOpLowering = BinaryOpLowering; + +struct AssignOpLowering : public ConversionPattern { + AssignOpLowering(MLIRContext* ctx) + : ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); + auto memref_type = operands[0].getType().cast(); + AssignOpAdaptor assign_adaptor(operands); + + SmallVector lower_bounds(memref_type.getRank(), 0); + SmallVector steps(memref_type.getRank(), 1); + buildAffineLoopNest( + rewriter, loc, lower_bounds, memref_type.getShape(), steps, + [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) { + auto loaded_lhs = nested_builder.create( + loc, assign_adaptor.lhs(), ivs); + nested_builder.create( + loc, loaded_lhs, assign_adaptor.rhs(), ivs); + }); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(jit::ReturnOp op, + PatternRewriter& rewriter) const final { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + +class MgbToAffineLoweringPass + : public PassWrapper { +public: + void runOnFunction() override final { + auto function = getFunction(); + + // Verify that the given main has no inputs and results. + if (function.getType().getNumResults()) { + mgb_log_error("expected 'main' to have 0 results"); + return signalPassFailure(); + } + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + + OwningRewritePatternList patterns; + patterns.insert( + &getContext()); + + if (failed(applyPartialConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr mgb::jit::create_lower_to_affine_pass() { + return std::make_unique(); +} + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp new file mode 100644 index 00000000..ad72760f --- /dev/null +++ b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp @@ -0,0 +1,211 @@ +/** + * \file src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/common.h" +#include "megbrain/jit/mlir/ir/dialect.h" +#include "megbrain/jit/mlir/ir/passes.h" + +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace mgb; +using namespace jit; + +namespace { + +mlir::Value get_operand(ConversionPatternRewriter& rewriter, + const mlir::Location& loc, const mlir::Value& val, + const mlir::Value& index) { + if (val.getType().isa()) { + return rewriter.create(loc, val, index); + } else { + return val; + } +} + +mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) { + auto thread_idx = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); + auto block_idx = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); + auto group_size = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); + Value index = rewriter.create( + loc, thread_idx, + rewriter.create(loc, block_idx, group_size)); + + return index; +} + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx), + m_launch_op{launch_op} {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); + + typename BinaryOp::Adaptor binary_adaptor(operands); + rewriter.setInsertionPointToEnd(&(m_launch_op->body().front())); + + auto index = get_tid(rewriter, loc); + auto loaded_lhs = + get_operand(rewriter, loc, binary_adaptor.lhs(), index); + auto loaded_rhs = + get_operand(rewriter, loc, binary_adaptor.rhs(), index); + + auto binary_op = + rewriter.create(loc, loaded_lhs, loaded_rhs); + + rewriter.replaceOp(op, binary_op.getResult()); + return success(); + } + +private: + gpu::LaunchOp* m_launch_op; +}; + +using AddOpLowering = BinaryOpLowering; + +struct ReturnOpLowering : public ConversionPattern { + ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) + : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx), + m_launch_op{launch_op} {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOpWithNewOp(op); + auto loc = op->getLoc(); + + //! remove the first gpu.terminator + m_launch_op->body().front().front().erase(); + + //! 
insert the guard "if (tid >= nr_tid) { return; }" at the beginning of the block
+        rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
+        Block* cond_block = rewriter.getInsertionBlock();
+        Block::iterator op_position = rewriter.getInsertionPoint();
+        Block* remaining_ops_block =
+                rewriter.splitBlock(cond_block, op_position);
+        rewriter.setInsertionPointToEnd(cond_block);
+
+        auto index = get_tid(rewriter, loc);
+        auto comparison = rewriter.create<CmpIOp>(
+                loc, CmpIPredicate::sge, index,
+                m_launch_op->getParentOfType<FuncOp>()
+                        .getArguments()
+                        .back());
+
+        Block* then_block =
+                rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
+        rewriter.setInsertionPointToEnd(then_block);
+        rewriter.create<gpu::TerminatorOp>(loc);
+
+        rewriter.setInsertionPointToEnd(cond_block);
+        rewriter.create<CondBranchOp>(
+                loc, comparison, then_block, ArrayRef<Value>(),
+                remaining_ops_block, ArrayRef<Value>());
+
+        rewriter.setInsertionPointToEnd(remaining_ops_block);
+        rewriter.create<gpu::TerminatorOp>(loc);
+
+        return success();
+    }
+
+private:
+    gpu::LaunchOp* m_launch_op;
+};
+
+struct AssignOpLowering : public ConversionPattern {
+    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+            : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
+              m_launch_op{launch_op} {}
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+
+        AssignOpAdaptor assign_adaptor(operands);
+        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+
+        auto index = get_tid(rewriter, loc);
+
+        auto loaded_lhs =
+                get_operand(rewriter, loc, assign_adaptor.lhs(), index);
+        rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);
+
+        rewriter.eraseOp(op);
+        return success();
+    }
+
+private:
+    gpu::LaunchOp* m_launch_op;
+};
+
+class MgbToGpuLoweringPass
+        : public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
+public:
+    void runOnFunction() override final {
+        auto func_op = getFunction();
+        Location loc = func_op.getLoc();
+        OpBuilder builder(&func_op.getBody());
+        Value constantOne = builder.create<ConstantIndexOp>(loc, 1);
+        gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>(
+                loc, constantOne, constantOne, constantOne, constantOne,
+                constantOne, constantOne);
+        builder.setInsertionPointToEnd(&(launch_op.body().front()));
+        builder.create<gpu::TerminatorOp>(loc);
+
+        OwningRewritePatternList patterns;
+        ConversionTarget target(getContext());
+        target.addLegalDialect<StandardOpsDialect>();
+        target.addLegalDialect<gpu::GPUDialect>();
+        target.addIllegalDialect<MgbDialect>();
+
+        patterns.insert<AddOpLowering, AssignOpLowering, ReturnOpLowering>(
+                &getContext(), &launch_op);
+
+        if (failed(applyPartialConversion(func_op, target, patterns))) {
+            signalPassFailure();
+        }
+    }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
+    return std::make_unique<MgbToGpuLoweringPass>();
+}
+
+#endif // MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
diff --git a/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp b/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp
new file mode 100644
index 00000000..ac0b2243
--- /dev/null
+++ b/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp
@@ -0,0 +1,56 @@
+/**
+ * \file src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include "megbrain/common.h"
+#include "megbrain/jit/mlir/ir/dialect.h"
+#include "megbrain/jit/mlir/ir/passes.h"
+
+#include
+#include
+#include
+#include
+
+using namespace mgb;
+using namespace jit;
+
+namespace {
+
+class AffineToLLVMLoweringPass
+        : public PassWrapper<AffineToLLVMLoweringPass,
+                             OperationPass<ModuleOp>> {
+    void runOnOperation() final {
+        LLVMConversionTarget target(getContext());
+        target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
+        LLVMTypeConverter typeConverter(&getContext());
+
+        OwningRewritePatternList patterns;
+        populateAffineToStdConversionPatterns(patterns, &getContext());
+        populateLoopToStdConversionPatterns(patterns, &getContext());
+        populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+        auto module = getOperation();
+        if (failed(applyFullConversion(module, target, patterns)))
+            signalPassFailure();
+    }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_llvm_pass() {
+    return std::make_unique<AffineToLLVMLoweringPass>();
+}
+
+#endif // MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
diff --git a/src/jit/impl/mlir/ir/ops.td b/src/jit/impl/mlir/ir/ops.td
new file mode 100644
index 00000000..05fdf018
--- /dev/null
+++ b/src/jit/impl/mlir/ir/ops.td
@@ -0,0 +1,72 @@
+/**
+ * \file src/jit/impl/mlir/ir/ops.td
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#ifndef MGB_MLIR_OPS
+#define MGB_MLIR_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+include "./shape_inference_interface.td"
+
+def Mgb_Dialect : Dialect {
+    let name = "mgb";
+    let cppNamespace = "mgb::jit";
+}
+
+class ElemwiseOp<string mnemonic, list<OpTrait> traits = []> :
+    Op<Mgb_Dialect, mnemonic, traits>;
+
+class GenericOp<string mnemonic, list<OpTrait> traits = []> :
+    Op<Mgb_Dialect, mnemonic, traits>;
+
+def AddOp : ElemwiseOp<"add",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+    let summary = "element-wise addition operation";
+    let description = [{
+        The "add" operation performs element-wise addition between two tensors.
+        The shapes of the tensor operands are expected to match.
+    }];
+
+    let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
+    let results = (outs F32MemRef);
+
+    // Specify a parser and printer method.
+    let parser = [{ return ::parseBinaryOp(parser, result); }];
+    let printer = [{ return ::printBinaryOp(p, *this); }];
+
+    // Allow building an AddOp from the two input operands.
+    let builders = [
+        OpBuilder<"OpBuilder &b, OperationState &state, Value lhs, Value rhs">
+    ];
+}
+
+def ReturnOp : GenericOp<"return",
+    [NoSideEffect, HasParent<"FuncOp">, Terminator]> {
+    let summary = "return operation";
+    let description = [{
+        The "return" operation represents a return within a function.
+        The operation takes no operands and produces no results.
+ }]; + +} + +def AssignOp : GenericOp<"assign", []> { + let summary = "assign op"; + let description = [{ + assign rhs to lhs without results + }]; + + let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs); +} + +#endif diff --git a/src/jit/impl/mlir/ir/shape_inference_interface.td b/src/jit/impl/mlir/ir/shape_inference_interface.td new file mode 100644 index 00000000..7ed48e5f --- /dev/null +++ b/src/jit/impl/mlir/ir/shape_inference_interface.td @@ -0,0 +1,30 @@ +/** + * \file src/jit/impl/mlir/ir/shape_inference_interface.td + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#ifndef MGB_JIT_SHAPE_INFERENCE_INTERFACE +#define MGB_JIT_SHAPE_INFERENCE_INTERFACE + +include "mlir/IR/OpBase.td" + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "infer_shapes"> + ]; +} + +#endif // MGB_SHAPE_INFERENCE_INTERFACE diff --git a/src/jit/impl/mlir/ir/shape_inference_pass.cpp b/src/jit/impl/mlir/ir/shape_inference_pass.cpp new file mode 100644 index 00000000..12f3174c --- /dev/null +++ b/src/jit/impl/mlir/ir/shape_inference_pass.cpp @@ -0,0 +1,100 @@ +/** + * \file src/jit/impl/mlir/ir/shape_inference_pass.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/common.h" +#include "megbrain/jit/mlir/ir/dialect.h" +#include "megbrain/jit/mlir/ir/passes.h" +#include "megbrain/jit/mlir/ir/shape_inference_interface.h" + +#include +#include +#include + +using namespace mgb; +using namespace jit; + +#include "megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc" + +namespace { +class ShapeInferencePass + : public mlir::PassWrapper { +public: + void runOnFunction() override { + auto f = getFunction(); + + llvm::SmallPtrSet op_worklist; + f.walk([&](mlir::Operation* op) { + if (returns_dynamic_shape(op)) + op_worklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have + // been inferred or no change happened (fix point). + while (!op_worklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(op_worklist, all_operands_inferred); + if (nextop == op_worklist.end()) + break; + + Operation* op = *nextop; + op_worklist.erase(op); + + if (auto shapeOp = dyn_cast(op)) { + shapeOp.infer_shapes(); + } else { + mgb_log_error( + "unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } + + // If the operation worklist isn't empty, this indicates a failure. 
+ if (!op_worklist.empty()) { + mgb_log_error( + "Shape inference failed, %zu operations couldn't be " + "inferred", + op_worklist.size()); + signalPassFailure(); + } + } + + //! A utility method that returns if the given operation has all of its + //! operands inferred. + static bool all_operands_inferred(Operation* op) { + return llvm::all_of(op->getOperandTypes(), [](Type operandType) { + return operandType.isa(); + }); + } + + //! A utility method that returns if the given operation has a dynamically + //! shaped result. + static bool returns_dynamic_shape(Operation* op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); + } +}; + +} // namespace + +std::unique_ptr mgb::jit::create_shape_inference_pass() { + return std::make_unique(); +} + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp new file mode 100644 index 00000000..aeb9f142 --- /dev/null +++ b/src/jit/impl/mlir/mlir_gen.cpp @@ -0,0 +1,207 @@ +/** + * \file src/jit/impl/mlir/mlir_gen.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "./mlir_gen.h" + +#include "./utils.h" +#include "megbrain/jit/mlir/ir/dialect.h" +#include "megbrain/opr/basic_arith.h" +#include "megdnn/dtype.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace mgb; +using namespace jit; + +namespace { +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext& context) : m_builder(&context) {} + + std::pair gen( + const InternalGraph& internal_graph, + const JITExecutor::Args& args) { + mlir::ModuleOp module = + mlir::ModuleOp::create(m_builder.getUnknownLoc()); + + //! Create main routine function + auto func_op = gen_func_op(internal_graph, args); + module.push_back(func_op); + + if (mlir::failed(mlir::verify(module))) { + module.emitError("module verification error"); + return {}; + } + + return {func_op.getName(), module}; + } + +private: + mlir::OpBuilder m_builder; + llvm::ScopedHashTable m_symbol_table; + + mlir::FuncOp gen_func_op(const InternalGraph& internal_graph, + const JITExecutor::Args& args) { + llvm::ScopedHashTableScope var_scope( + m_symbol_table); + std::vector func_args; + for (auto&& arg : args.inputs) { + func_args.push_back(get_type(arg.layout)); + } + for (auto&& arg : args.outputs) { + func_args.push_back(get_type(arg.layout)); + } + //! the last arg is nr_elements + func_args.push_back(m_builder.getIndexType()); + + auto func_type = m_builder.getFunctionType(func_args, llvm::None); + //! 
function name maybe renamed in later pass + mlir::FuncOp func_op = mlir::FuncOp::create(m_builder.getUnknownLoc(), + "func", func_type); + if (!func_op) + return nullptr; + + func_op.setAttr("llvm.emit_c_interface", + mlir::UnitAttr::get(m_builder.getContext())); + auto& entry_block = *func_op.addEntryBlock(); + size_t idx = 0; + for (auto&& input : args.inputs) { + if (mlir::failed(declare(internal_graph.placeholders()[input.idx] + ->output(0) + ->name(), + entry_block.getArgument(idx)))) { + return nullptr; + } + idx++; + } + for (auto&& output : args.outputs) { + if (mlir::failed(declare(output.from->name(), + entry_block.getArgument(idx)))) { + return nullptr; + } + idx++; + } + + m_builder.setInsertionPointToStart(&entry_block); + + if (mlir::failed(gen_func_body(internal_graph, args))) { + func_op.erase(); + return nullptr; + } + + jit::ReturnOp return_op; + if (!return_op) { + m_builder.create(m_builder.getUnknownLoc()); + } + std::string op_content = to_string(func_op); + func_op.setName( + ssprintf("jit_mlir_%" PRIx64, + XXHash{}.update(op_content.data(), op_content.size()) + .digest())); + return func_op; + } + + mlir::LogicalResult gen_func_body(const InternalGraph& internal_graph, + const JITExecutor::Args& args) { + llvm::ScopedHashTableScope var_scope( + m_symbol_table); + cg::DepOprIter{[&](cg::OperatorNodeBase* opr) { + if (opr->same_type()) { + return; + } + + if (opr->same_type()) { + auto&& out = gen_op(opr->cast_final()); + mgb_assert( + mlir::succeeded(declare(opr->output(0)->name(), out))); + } + }}.add(internal_graph.output()); + m_builder.create(m_builder.getUnknownLoc(), + get(internal_graph.output()), + get(args.outputs[0].from)); + + return mlir::success(); + } + + mlir::Value gen_op(const opr::Elemwise& opr) { + switch (opr.param().mode) { + case opr::Elemwise::Mode::ADD: + return m_builder.create(m_builder.getUnknownLoc(), + get(opr.input(0)), + get(opr.input(1))); + break; + default: + return nullptr; + } + return nullptr; + } + + mlir::Type get_type(const TensorLayout& layout) { + std::vector shape; + for (size_t i = 0; i < layout.ndim; i++) { + shape.push_back(layout[i]); + } + mgb_assert(layout.ndim != 0); + switch (layout.dtype.enumv()) { + case DTypeEnum::Float32: + return mlir::MemRefType::get(shape, m_builder.getF32Type()); + default: + mgb_throw(InternalError, "No supported dtype: %s", + layout.dtype.name()); + } + return mlir::UnrankedMemRefType::get(m_builder.getNoneType(), 0); + } + + mlir::Value get(const VarNode* var) { + if (auto ret = m_symbol_table.lookup(var->name())) { + return ret; + } + mgb_throw(InternalError, "Unknown var: %s", var->cname()); + } + + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) { + if (m_symbol_table.count(var)) { + return mlir::failure(); + } + m_symbol_table.insert(var, value); + return mlir::success(); + } +}; +} // namespace + +std::pair mgb::jit::mlir_gen( + mlir::MLIRContext& context, + const mgb::jit::InternalGraph& internal_graph, + const mgb::jit::JITExecutor::Args& args) { + return MLIRGenImpl(context).gen(internal_graph, args); +} + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/mlir_gen.h b/src/jit/impl/mlir/mlir_gen.h new file mode 100644 index 00000000..25b79488 --- /dev/null +++ b/src/jit/impl/mlir/mlir_gen.h @@ -0,0 +1,42 @@ +/** + * \file src/jit/impl/mlir/mlir_gen.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/jit/executor_opr.h" +#include "megbrain/jit/internal_graph.h" + +#include + +namespace mgb { +namespace jit { + +/** + * \brief generate mlir from subgraph. + * + * \param context mlir context + * \param internal_graph internal graph used to generate mlir + * \param args input args for the internal graph + * \return A pair of {kernel_name, module} + **/ +std::pair mlir_gen( + mlir::MLIRContext& context, const InternalGraph& internal_graph, + const JITExecutor::Args& args); +} +} // namespace mgb + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/utils.h b/src/jit/impl/mlir/utils.h new file mode 100644 index 00000000..4d109af8 --- /dev/null +++ b/src/jit/impl/mlir/utils.h @@ -0,0 +1,45 @@ +/** + * \file src/jit/impl/mlir/utils.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/common.h" +#include "megbrain/exception.h" +#include "megdnn/basic_types.h" +#include "megdnn/dtype.h" + +#include + +#include + +#include + +namespace mgb { +namespace jit { + +template +std::string to_string(T&& t) { + std::string ret; + llvm::raw_string_ostream stream(ret); + t.print(stream); + return ret; +} + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/include/megbrain/jit/mlir/ir/dialect.h b/src/jit/include/megbrain/jit/mlir/ir/dialect.h new file mode 100644 index 00000000..706587e2 --- /dev/null +++ b/src/jit/include/megbrain/jit/mlir/ir/dialect.h @@ -0,0 +1,45 @@ +/** + * \file src/jit/impl/mlir/ir/dialect.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include +#include +#include +#include + +#include "megbrain/jit/mlir/ir/shape_inference_interface.h" + +namespace mgb { +namespace jit { + +class MgbDialect : public ::mlir::Dialect { +public: + explicit MgbDialect(::mlir::MLIRContext* ctx); + + //! 
The namespace used when this dialect's operations are registered.
+    static llvm::StringRef getDialectNamespace() { return "mgb::jit"; }
+};
+
+#define GET_OP_CLASSES
+using namespace mlir;
+#include "megbrain/jit/mlir/ir/ops.h.inc"
+
+} // namespace jit
+} // namespace mgb
+
+#endif // MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
diff --git a/src/jit/include/megbrain/jit/mlir/ir/passes.h b/src/jit/include/megbrain/jit/mlir/ir/passes.h
new file mode 100644
index 00000000..dc7eaaa7
--- /dev/null
+++ b/src/jit/include/megbrain/jit/mlir/ir/passes.h
@@ -0,0 +1,43 @@
+/**
+ * \file src/jit/impl/mlir/ir/passes.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include
+
+#include
+
+namespace mgb {
+namespace jit {
+
+std::unique_ptr<mlir::Pass> create_shape_inference_pass();
+
+/**
+ * \brief Create a pass for lowering to operations in the `Affine` and `Std`
+ * dialects, for a subset of the megbrain IR.
+ */
+std::unique_ptr<mlir::Pass> create_lower_to_affine_pass();
+
+std::unique_ptr<mlir::Pass> create_lower_to_llvm_pass();
+
+std::unique_ptr<mlir::Pass> create_lower_to_gpu_pass();
+
+} // namespace jit
+} // namespace mgb
+
+#endif // MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
diff --git a/src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h b/src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h
new file mode 100644
index 00000000..594e72ba
--- /dev/null
+++ b/src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h
@@ -0,0 +1,33 @@
+/**
+ * \file src/jit/impl/mlir/ir/shape_inference_interface.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mgb {
+namespace jit {
+
+/// Include the auto-generated declarations.
+#include "megbrain/jit/mlir/ir/shape_inference_interface.h.inc"
+
+} // namespace jit
+} // namespace mgb
+
+#endif // MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
diff --git a/src/jit/include/megbrain/jit/param_elem_visitor.h b/src/jit/include/megbrain/jit/param_elem_visitor.h
index 12740572..81c3dd13 100644
--- a/src/jit/include/megbrain/jit/param_elem_visitor.h
+++ b/src/jit/include/megbrain/jit/param_elem_visitor.h
@@ -18,8 +18,8 @@
  */
 /*!
- * * \brief fast division for unsigned int
- * */
+ * \brief fast division for unsigned int
+ */
 struct Uint32Fastdiv {
     unsigned int m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend,
             m_shift;
diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp
index bcafdc35..11cce30e 100644
--- a/src/jit/test/codegen.cpp
+++ b/src/jit/test/codegen.cpp
@@ -14,6 +14,7 @@
 #include "megbrain/jit/executor_opr.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
 #include "megbrain/test/helper.h"
+#include "megdnn/dtype.h"
 
 #if MGB_JIT
 using namespace mgb;
@@ -120,6 +121,44 @@ void run(Backend backend, CompNode cn) {
 
 template <>
 void run<void>(Backend, CompNode) {}
+
+#if MGB_JIT_MLIR
+void run_mlir(CompNode cn) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
+         host_x2 = gen({23, 42}, cn);
+
+    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
+         b = opr::Host2DeviceCopy::make(*graph, host_x1),
+         c = opr::Host2DeviceCopy::make(*graph, host_x2);
+
+    auto y = a + b + c;
+
+    VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()};
+    auto ig_gen =
+            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+#endif
+
 } // anonymous namespace
 
 #if MGB_JIT_HALIDE
@@ -140,6 +179,20 @@
 TYPED_TEST(TestJITNvrtcCodeGen, run) {
     run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0"));
 }
 
+#if MGB_JIT_MLIR
+TEST(TestJITMlirCodeGen, Basic) {
+    auto cn = CompNode::load("cpu0");
+    run_mlir(cn);
+}
+
+TEST(TestJITMlirCodeGen, BasicGPU) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    run_mlir(cn);
+}
+
+#endif
+
 #endif // MGB_JIT
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/jit/test/helper.cpp b/src/jit/test/helper.cpp
index c0311a5d..eef8f6ef 100644
--- a/src/jit/test/helper.cpp
+++ b/src/jit/test/helper.cpp
@@ -35,6 +35,9 @@ void jit::set_backend(Backend backend) {
         case Backend::NVRTC:
             setenv("MGB_JIT_BACKEND", "NVRTC", 1);
             return;
+        case Backend::MLIR:
+            setenv("MGB_JIT_BACKEND", "MLIR", 1);
+            return;
         default:
             mgb_assert(0);
     }
diff --git a/src/jit/test/helper.h b/src/jit/test/helper.h
index 33d3e1b7..f69df975 100644
--- a/src/jit/test/helper.h
+++ b/src/jit/test/helper.h
@@ -15,7 +15,7 @@
 namespace mgb {
 namespace jit {
 
-enum class Backend { NONE, HALIDE, NVRTC };
+enum class Backend { NONE, HALIDE, NVRTC, MLIR };
 
 void set_backend(Backend backend);
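A minimal usage sketch (editorial; it assumes a build with MGB_JIT_MLIR enabled and JIT fusion active, and reuses the names from the run_mlir test above). The backend is selected through the MGB_JIT_BACKEND environment variable that Compiler::get reads:

    setenv("MGB_JIT_BACKEND", "MLIR", 1);  // what jit::set_backend(Backend::MLIR) does
    auto cn = CompNode::load("cpu0");      // "cpu0" uses the ExecutionEngine path,
                                           // "gpu0" the PTX/cubin path
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto host_x = gen({23, 42}, cn), host_y = gen({23, 42}, cn);
    auto a = opr::Host2DeviceCopy::make(*graph, host_x),
         b = opr::Host2DeviceCopy::make(*graph, host_y);
    auto sum = a + b;  // elementwise chain eligible for fusion into one mgb.add kernel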