GitOrigin-RevId: dd45984cca
release-1.1
@@ -49,6 +49,14 @@ function(external_tablegen_library) | |||||
install(TARGETS ${_NAME} EXPORT ${MGE_EXPORT_TARGETS}) | install(TARGETS ${_NAME} EXPORT ${MGE_EXPORT_TARGETS}) | ||||
endfunction() | endfunction() | ||||
set(LLVM_LIBS LLVMCore LLVMSupport LLVMX86CodeGen LLVMOrcJIT LLVMNVPTXCodeGen LLVMNVPTXDesc LLVMNVPTXInfo) | |||||
set(MLIR_CORE_LIBS MLIRAnalysis MLIRExecutionEngine MLIRIR MLIRParser MLIRPass MLIRSideEffectInterfaces MLIRTransforms) | |||||
set(MLIR_DIALECT_LIBS MLIRAsync MLIRAVX512 MLIRGPU MLIRLLVMAVX512 MLIRNVVMIR MLIROpenACC MLIRPDL MLIRPDLInterp MLIRQuant MLIRROCDLIR MLIRSDBM MLIRShape MLIRSPIRV MLIRStandardOpsTransforms) | |||||
set(MLIR_CONVERSION_LIBS MLIRAffineToStandard MLIRAVX512ToLLVM MLIRGPUToGPURuntimeTransforms MLIRGPUToNVVMTransforms MLIRSCFToStandard) | |||||
set(MLIR_TRANSLATION_LIBS MLIRTargetLLVMIR MLIRTargetNVVMIR) | |||||
set(MLIR_LIBS ${MLIR_CORE_LIBS} ${MLIR_DIALECT_LIBS} ${MLIR_CONVERSION_LIBS} ${MLIR_TRANSLATION_LIBS}) | |||||
set(MLIR_LLVM_LIBS ${LLVM_LIBS} ${MLIR_LIBS}) | |||||
if (MGE_USE_SYSTEM_LIB) | if (MGE_USE_SYSTEM_LIB) | ||||
find_package(ZLIB) | find_package(ZLIB) | ||||
find_package(MLIR REQUIRED CONFIG) | find_package(MLIR REQUIRED CONFIG) | ||||
@@ -77,9 +85,7 @@ if (MGE_USE_SYSTEM_LIB) | |||||
endif() | endif() | ||||
endfunction(find_mlir_llvm_lib) | endfunction(find_mlir_llvm_lib) | ||||
set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms) | |||||
foreach(c ${MLIR_COMPONENTS}) | |||||
foreach(c ${MLIR_LIBS}) | |||||
find_mlir_llvm_lib(${c}) | find_mlir_llvm_lib(${c}) | ||||
endforeach() | endforeach() | ||||
return() | return() | ||||
@@ -119,5 +125,3 @@ set(MLIR_LLVM_INCLUDE_DIR | |||||
${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include | ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include | ||||
) | ) | ||||
set(MLIR_TABLEGEN_EXE mlir-tblgen) | set(MLIR_TABLEGEN_EXE mlir-tblgen) | ||||
set(MLIR_LLVM_LIBS LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms) |
@@ -67,8 +67,8 @@ mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location, | |||||
} | } | ||||
std::unique_ptr<llvm::Module> translate_module_to_nvvm_ir_and_link_device( | std::unique_ptr<llvm::Module> translate_module_to_nvvm_ir_and_link_device( | ||||
Operation* m) { | |||||
std::unique_ptr<llvm::Module> module = mlir::translateModuleToNVVMIR(m); | |||||
Operation* m, llvm::LLVMContext& llvmContext, llvm::StringRef name) { | |||||
std::unique_ptr<llvm::Module> module = mlir::translateModuleToNVVMIR(m, llvmContext); | |||||
auto get_device_path = []() -> std::string { | auto get_device_path = []() -> std::string { | ||||
auto cuda_path = getenv("CUDA_BIN_PATH"); | auto cuda_path = getenv("CUDA_BIN_PATH"); | ||||
std::string device_dir; | std::string device_dir; | ||||
@@ -223,6 +223,7 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module, | |||||
std::unique_ptr<Executable> MLIRCompiler::do_compile( | std::unique_ptr<Executable> MLIRCompiler::do_compile( | ||||
const InternalGraph& graph, const JITExecutor::Args& args) { | const InternalGraph& graph, const JITExecutor::Args& args) { | ||||
mlir::MLIRContext ctx; | mlir::MLIRContext ctx; | ||||
ctx.getOrLoadDialect<MgbDialect>(); | |||||
ctx.printStackTraceOnDiagnostic(true); | ctx.printStackTraceOnDiagnostic(true); | ||||
ctx.printOpOnDiagnostic(true); | ctx.printOpOnDiagnostic(true); | ||||
@@ -24,7 +24,8 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace jit; | using namespace jit; | ||||
MgbDialect::MgbDialect(mlir::MLIRContext* ctx) : mlir::Dialect("mgb", ctx) { | |||||
MgbDialect::MgbDialect(mlir::MLIRContext* ctx) | |||||
: mlir::Dialect("mgb", ctx, mlir::TypeID::get<MgbDialect>()) { | |||||
addOperations< | addOperations< | ||||
#define GET_OP_LIST | #define GET_OP_LIST | ||||
#include "megbrain/jit/mlir/ir/ops.cpp.inc" | #include "megbrain/jit/mlir/ir/ops.cpp.inc" | ||||
@@ -209,6 +209,11 @@ struct ConstantScalarOpLowering | |||||
class MgbToAffineLoweringPass | class MgbToAffineLoweringPass | ||||
: public PassWrapper<MgbToAffineLoweringPass, FunctionPass> { | : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> { | ||||
public: | public: | ||||
void getDependentDialects(mlir::DialectRegistry& registry) const override { | |||||
registry.insert<mlir::AffineDialect>(); | |||||
registry.insert<mlir::StandardOpsDialect>(); | |||||
} | |||||
void runOnFunction() override final { | void runOnFunction() override final { | ||||
ConversionTarget target(getContext()); | ConversionTarget target(getContext()); | ||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | ||||
@@ -259,6 +259,11 @@ private: | |||||
class MgbToGpuLoweringPass | class MgbToGpuLoweringPass | ||||
: public PassWrapper<MgbToGpuLoweringPass, FunctionPass> { | : public PassWrapper<MgbToGpuLoweringPass, FunctionPass> { | ||||
public: | public: | ||||
void getDependentDialects(mlir::DialectRegistry& registry) const override { | |||||
registry.insert<mlir::gpu::GPUDialect>(); | |||||
registry.insert<mlir::StandardOpsDialect>(); | |||||
} | |||||
void runOnFunction() override final { | void runOnFunction() override final { | ||||
auto func_op = getFunction(); | auto func_op = getFunction(); | ||||
Location loc = func_op.getLoc(); | Location loc = func_op.getLoc(); | ||||
@@ -21,6 +21,8 @@ | |||||
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h> | #include <mlir/Conversion/SCFToStandard/SCFToStandard.h> | ||||
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h> | #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h> | ||||
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h> | #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h> | ||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h> | |||||
#include <mlir/Dialect/SCF/SCF.h> | |||||
#include <mlir/Dialect/StandardOps/Transforms/Passes.h> | #include <mlir/Dialect/StandardOps/Transforms/Passes.h> | ||||
using namespace mgb; | using namespace mgb; | ||||
@@ -30,6 +32,12 @@ namespace { | |||||
class AffineToLLVMLoweringPass : public PassWrapper<AffineToLLVMLoweringPass, | class AffineToLLVMLoweringPass : public PassWrapper<AffineToLLVMLoweringPass, | ||||
OperationPass<ModuleOp>> { | OperationPass<ModuleOp>> { | ||||
public: | |||||
void getDependentDialects(mlir::DialectRegistry& registry) const override { | |||||
registry.insert<mlir::LLVM::LLVMDialect>(); | |||||
registry.insert<mlir::scf::SCFDialect>(); | |||||
} | |||||
void runOnOperation() final { | void runOnOperation() final { | ||||
LLVMConversionTarget target(getContext()); | LLVMConversionTarget target(getContext()); | ||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); | target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); | ||||
@@ -21,7 +21,7 @@ namespace jit { | |||||
inline bool is_elemwise_float(const mlir::Type& dt) { | inline bool is_elemwise_float(const mlir::Type& dt) { | ||||
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | ||||
if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { | |||||
if (cast.getElementType().isF32()) { | |||||
return true; | return true; | ||||
} | } | ||||
} | } | ||||
@@ -82,13 +82,12 @@ megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) { | |||||
if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { | if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { | ||||
element_type = cast.getElementType(); | element_type = cast.getElementType(); | ||||
} | } | ||||
switch (element_type.getKind()) { | |||||
case mlir::StandardTypes::F32: | |||||
return megdnn::dtype::Float32{}; | |||||
default: | |||||
mgb_throw(InternalError, | |||||
"Unsupport mlir type for MemRefType, got: %s\n", | |||||
mlir_type_to_string(type).c_str()); | |||||
if (element_type.isF32()) { | |||||
return megdnn::dtype::Float32{}; | |||||
} else { | |||||
mgb_throw(InternalError, | |||||
"Unsupport mlir type for MemRefType, got: %s\n", | |||||
mlir_type_to_string(type).c_str()); | |||||
} | } | ||||
return {}; | return {}; | ||||
} | } | ||||
@@ -34,13 +34,13 @@ public: | |||||
static llvm::StringRef getDialectNamespace() { return "mgb::jit"; } | static llvm::StringRef getDialectNamespace() { return "mgb::jit"; } | ||||
}; | }; | ||||
} // namespace jit | |||||
} // namespace mgb | |||||
#define GET_OP_CLASSES | #define GET_OP_CLASSES | ||||
using namespace mlir; | using namespace mlir; | ||||
#include "megbrain/jit/mlir/ir/ops.h.inc" | #include "megbrain/jit/mlir/ir/ops.h.inc" | ||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | #endif // MGB_JIT && MGB_JIT_MLIR | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |