From 9767ca8f19c654a44f9dbd855d2b18794b35455a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 27 Aug 2020 00:31:17 +0800 Subject: [PATCH] feat(mgb/jit): refactor code and add more elemwise mode GitOrigin-RevId: eb6bcadf54d3111a94152c10d0fb826ca2a16d66 --- cmake/llvm-project.cmake | 4 +- src/jit/impl/mlir/compiler.cpp | 2 - src/jit/impl/mlir/executable_cpu.cpp | 5 +- src/jit/impl/mlir/executable_cuda.cpp | 10 +- src/jit/impl/mlir/ir/CMakeLists.txt | 6 +- src/jit/impl/mlir/ir/common.cpp | 89 ++++- src/jit/impl/mlir/ir/common.h | 65 +++- src/jit/impl/mlir/ir/dialect.cpp | 63 +--- src/jit/impl/mlir/ir/each_mode.h | 412 +++++++++++++++++++++ src/jit/impl/mlir/ir/interfaces.td | 33 ++ src/jit/impl/mlir/ir/lower_to_affine_pass.cpp | 104 +++++- src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp | 101 ++++- src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp | 2 + src/jit/impl/mlir/ir/ops.td | 152 +++++++- src/jit/impl/mlir/ir/shape_inference_interface.td | 30 -- src/jit/impl/mlir/ir/shape_inference_pass.cpp | 100 ----- src/jit/impl/mlir/ir/utils.cpp | 111 ++++++ src/jit/impl/mlir/mlir_gen.cpp | 53 +-- src/jit/impl/utils.cpp | 2 +- src/jit/include/megbrain/jit/mlir/ir/dialect.h | 7 +- .../{shape_inference_interface.h => interfaces.h} | 21 +- src/jit/include/megbrain/jit/mlir/ir/passes.h | 6 +- .../mlir => include/megbrain/jit/mlir/ir}/utils.h | 30 +- src/jit/test/codegen.cpp | 158 +++++++- src/opr/include/megbrain/opr/basic_arith_wrapper.h | 1 + 25 files changed, 1246 insertions(+), 321 deletions(-) create mode 100644 src/jit/impl/mlir/ir/each_mode.h create mode 100644 src/jit/impl/mlir/ir/interfaces.td delete mode 100644 src/jit/impl/mlir/ir/shape_inference_interface.td delete mode 100644 src/jit/impl/mlir/ir/shape_inference_pass.cpp create mode 100644 src/jit/impl/mlir/ir/utils.cpp rename src/jit/include/megbrain/jit/mlir/ir/{shape_inference_interface.h => interfaces.h} (62%) rename src/jit/{impl/mlir => include/megbrain/jit/mlir/ir}/utils.h (52%) diff --git a/cmake/llvm-project.cmake b/cmake/llvm-project.cmake index 9bbb3c23..561435c5 100644 --- a/cmake/llvm-project.cmake +++ b/cmake/llvm-project.cmake @@ -77,7 +77,7 @@ if (MGE_USE_SYSTEM_LIB) endif() endfunction(find_mlir_llvm_lib) - set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms) + set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms) foreach(c ${MLIR_COMPONENTS}) find_mlir_llvm_lib(${c}) @@ -120,4 +120,4 @@ set(MLIR_LLVM_INCLUDE_DIR ) set(MLIR_TABLEGEN_EXE mlir-tblgen) -set(MLIR_LLVM_LIBS 
LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms) +set(MLIR_LLVM_LIBS LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms) diff --git a/src/jit/impl/mlir/compiler.cpp b/src/jit/impl/mlir/compiler.cpp index 6e3ae023..0ebb2fba 100644 --- a/src/jit/impl/mlir/compiler.cpp +++ b/src/jit/impl/mlir/compiler.cpp @@ -64,7 +64,6 @@ mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location, void add_cpu_lowering_pass(mlir::PassManager& manager) { { mlir::OpPassManager& opt_pm = manager.nest(); - opt_pm.addPass(create_shape_inference_pass()); opt_pm.addPass(mlir::createCanonicalizerPass()); opt_pm.addPass(mlir::createCSEPass()); } @@ -84,7 +83,6 @@ void add_cpu_lowering_pass(mlir::PassManager& manager) { void add_cuda_lowering_pass(mlir::PassManager& manager, CompNode cn) { { mlir::OpPassManager& opt_pm = manager.nest(); - opt_pm.addPass(create_shape_inference_pass()); opt_pm.addPass(mlir::createCanonicalizerPass()); opt_pm.addPass(mlir::createCSEPass()); } diff --git a/src/jit/impl/mlir/executable_cpu.cpp b/src/jit/impl/mlir/executable_cpu.cpp index d9190736..35c8368e 100644 --- a/src/jit/impl/mlir/executable_cpu.cpp +++ b/src/jit/impl/mlir/executable_cpu.cpp @@ -14,9 +14,10 @@ #if MGB_JIT && MGB_JIT_MLIR #include "./executable_cpu.h" -#include "./utils.h" +#include "megbrain/jit/mlir/ir/utils.h" #include +#include using namespace mgb; using namespace jit; @@ -113,7 +114,7 @@ void MLIRCPUExecutable::execute(JITExecutor* fusion_opr) { idx++; } - args_array_pointer[idx++] = &nr_elements; + args_array_pointer.push_back(&nr_elements); std::string adapter_name = std::string("_mlir_ciface_") + m_kernel_name; auto err = m_engine->invoke( adapter_name, llvm::MutableArrayRef(args_array_pointer)); diff --git a/src/jit/impl/mlir/executable_cuda.cpp b/src/jit/impl/mlir/executable_cuda.cpp index 2dea7499..2921631e 100644 --- a/src/jit/impl/mlir/executable_cuda.cpp +++ b/src/jit/impl/mlir/executable_cuda.cpp @@ -17,13 +17,15 @@ #if MGB_CUDA #include "./executable_cuda.h" -#include "./utils.h" -#include "megbrain/utils/timer.h" -#include "megbrain/utils/persistent_cache.h" + #include "megbrain/comp_node_env.h" +#include "megbrain/jit/mlir/ir/utils.h" +#include "megbrain/utils/persistent_cache.h" +#include "megbrain/utils/timer.h" -#include #include +#include +#include #include using namespace mgb; diff --git a/src/jit/impl/mlir/ir/CMakeLists.txt b/src/jit/impl/mlir/ir/CMakeLists.txt index 405a5341..5c89c7c3 100644 --- a/src/jit/impl/mlir/ir/CMakeLists.txt +++ b/src/jit/impl/mlir/ir/CMakeLists.txt @@ -8,12 +8,12 @@ external_tablegen_library( TBLGEN MLIR SRCS - "shape_inference_interface.td" + "interfaces.td" INCLUDES ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR} OUTS - -gen-op-interface-decls 
include/megbrain/jit/mlir/ir/shape_inference_interface.h.inc - -gen-op-interface-defs include/megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc + -gen-op-interface-decls include/megbrain/jit/mlir/ir/interfaces.h.inc + -gen-op-interface-defs include/megbrain/jit/mlir/ir/interfaces.cpp.inc ) external_tablegen_library( diff --git a/src/jit/impl/mlir/ir/common.cpp b/src/jit/impl/mlir/ir/common.cpp index cd22eda6..e46ba6c2 100644 --- a/src/jit/impl/mlir/ir/common.cpp +++ b/src/jit/impl/mlir/ir/common.cpp @@ -13,29 +13,88 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR -#include "common.h" +#include "./common.h" -#include +#include "mlir/Dialect/StandardOps/IR/Ops.h" using namespace mgb; using namespace jit; -mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, - mlir::Location loc, - mlir::PatternRewriter& rewriter) { - auto alloc = rewriter.create(loc, type); +#define cb(name, op) \ + mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \ + return m_builder.create(m_location, lhs, rhs); \ + } +cb(add, AddFOp); +cb(sub, SubFOp); +cb(mul, MulFOp); +cb(div, DivFOp); +cb(mod, RemFOp); +#undef cb - // Make sure to allocate at the beginning of the block. - auto* parent_block = alloc.getOperation()->getBlock(); - alloc.getOperation()->moveBefore(&parent_block->front()); +#define cb(name, mode) \ + mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \ + return m_builder.create( \ + m_location, mlir::CmpFPredicate::mode, lhs, rhs); \ + } +cb(gt, OGT); +cb(ge, OGE); +cb(lt, OLT); +cb(le, OLE); +cb(eq, OEQ); +#undef cb - // Make sure to deallocate this alloc at the end of the block. This is fine - // as toy functions have no control flow. - auto dealloc = rewriter.create(loc, alloc); - dealloc.getOperation()->moveBefore(&parent_block->back()); - return alloc; +mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) { + mlir::Value cmp = m_builder.create( + m_location, mlir::CmpFPredicate::OLT, lhs, rhs); + return m_builder.create(m_location, cmp, lhs, rhs); +} + +mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) { + mlir::Value cmp = m_builder.create( + m_location, mlir::CmpFPredicate::OGT, lhs, rhs); + return m_builder.create(m_location, cmp, lhs, rhs); +} + +mlir::Value ValueBuilderHelper::const_val(float val) { + return m_builder.create(m_location, + m_builder.getF32FloatAttr(val)); +} + +#define cb(name, op) \ + mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \ + return m_builder.create(m_location, lhs); \ + } + +cb(neg, NegFOp); +cb(abs, AbsFOp); +cb(ceil, CeilFOp); +cb(cos, CosOp); +cb(exp, ExpOp); +cb(exp2, Exp2Op); +cb(log10, Log10Op); +cb(log2, Log2Op); +cb(rsqrt, RsqrtOp); +cb(sin, SinOp); +cb(sqrt, SqrtOp); +cb(tanh, TanhOp); +#undef cb + +mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) { + //! 
FIXME use standard floor when upgrade llvm + return neg(ceil(neg(lhs))); +} + +mlir::Value ValueBuilderHelper::log(mlir::Value lhs) { + // math.log10(math.e) = 0.4342944819032518f + return div(log10(lhs), const_val(0.4342944819032518f)); +} + +mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val, + mlir::Value false_val) { + return m_builder.create(m_location, cond, true_val, + false_val); } #endif // MGB_JIT && MGB_JIT_MLIR -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/common.h b/src/jit/impl/mlir/ir/common.h index ee01ff6c..22bd7080 100644 --- a/src/jit/impl/mlir/ir/common.h +++ b/src/jit/impl/mlir/ir/common.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -14,19 +15,71 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR -#include -#include +#include +#include #include namespace mgb { namespace jit { -mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, - mlir::PatternRewriter& rewriter); +/** + * \brief Helper function for common value builder + */ +class ValueBuilderHelper { +public: + ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location) + : m_builder{b}, m_location{location} {}; + +#define cb(name) \ + mlir::Value name(mlir::ValueRange operands) { \ + return name(operands[0], operands[1]); \ + } \ + mlir::Value name(mlir::Value lhs, mlir::Value rhs) + cb(add); + cb(sub); + cb(mul); + cb(div); + cb(max); + cb(min); + cb(mod); + cb(gt); + cb(ge); + cb(lt); + cb(le); + cb(eq); +#undef cb + mlir::Value const_val(float val); + +#define cb(name) \ + mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \ + mlir::Value name(mlir::Value lhs) + cb(neg); + cb(abs); + cb(ceil); + cb(floor); + cb(cos); + cb(exp); + cb(exp2); + cb(log10); + cb(log2); + cb(log); + cb(rsqrt); + cb(sin); + cb(sqrt); + cb(tanh); +#undef cb + + mlir::Value select(mlir::Value cond, mlir::Value true_val, + mlir::Value false_val); + +private: + mlir::OpBuilder& m_builder; + mlir::Location m_location; +}; } // namespace jit } // namespace mgb #endif // MGB_JIT && MGB_JIT_MLIR -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/dialect.cpp b/src/jit/impl/mlir/ir/dialect.cpp index 35576109..41bbbcce 100644 --- a/src/jit/impl/mlir/ir/dialect.cpp +++ b/src/jit/impl/mlir/ir/dialect.cpp @@ -15,77 +15,26 @@ #include "megbrain/jit/mlir/ir/dialect.h" -#include #include #include #include +#include using namespace mgb; using namespace jit; -MgbDialect::MgbDialect(mlir::MLIRContext *ctx) : mlir::Dialect("mgb", ctx) { - addOperations< +MgbDialect::MgbDialect(mlir::MLIRContext* ctx) : mlir::Dialect("mgb", ctx) { + addOperations< #define GET_OP_LIST #include "megbrain/jit/mlir/ir/ops.cpp.inc" - >(); -} - -static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - SmallVector operands; - llvm::SMLoc operandsLoc = parser.getCurrentLocation(); - Type type; - if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return mlir::failure(); - - // If the type 
is a function type, it contains the input and result types of - // this operation. - if (FunctionType funcType = type.dyn_cast()) { - if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, - result.operands)) - return mlir::failure(); - result.addTypes(funcType.getResults()); - return mlir::success(); - } - - // Otherwise, the parsed type is the type of both operands and results. - if (parser.resolveOperands(operands, type, result.operands)) - return mlir::failure(); - result.addTypes(type); - return mlir::success(); -} - -static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << op->getName() << " " << op->getOperands(); - printer.printOptionalAttrDict(op->getAttrs()); - printer << " : "; - - // If all of the types are the same, print the type directly. - Type resultType = *op->result_type_begin(); - if (llvm::all_of(op->getOperandTypes(), - [=](Type type) { return type == resultType; })) { - printer << resultType; - return; - } - - // Otherwise, print a functional type. - printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); + >(); } -///////////////////////// ElemwiseOp ///////////////////////////////////////////// - -void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs) { - state.addTypes(lhs.getType()); - state.addOperands({lhs, rhs}); -} -void AddOp::infer_shapes() { getResult().setType(getOperand(0).getType()); } - #define GET_OP_CLASSES #include "megbrain/jit/mlir/ir/ops.cpp.inc" +#include "megbrain/jit/mlir/ir/interfaces.cpp.inc" + #endif // MGB_JIT && MGB_JIT_MLIR // vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/each_mode.h b/src/jit/impl/mlir/ir/each_mode.h new file mode 100644 index 00000000..53f2b536 --- /dev/null +++ b/src/jit/impl/mlir/ir/each_mode.h @@ -0,0 +1,412 @@ +/** + * \file src/jit/impl/mlir/ir/each_mode.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/jit/mlir/ir/dialect.h" + +#include "./common.h" + +#include +#include +#include + +// clang-format off +#define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \ + cb(ReluOp, RELU) \ + cb(AbsOp, ABS) \ + cb(NegOp, NEGATE) \ + cb(CeilOp, CEIL) \ + cb(CosOp, COS) \ + cb(ExpOp, EXP) \ + cb(FloorOp, FLOOR) \ + cb(LogOp, LOG) \ + cb(Log1POp, LOG1P) \ + cb(SigmoidOp, SIGMOID) \ + cb(SinOp, SIN) \ + cb(TanhOp, TANH) \ + cb(FastTanhOp, FAST_TANH) \ + cb(HswishOp, H_SWISH) \ + cb(ExpM1Op, EXPM1) \ + cb(RoundOp, ROUND) + +#define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \ + cb(AbsGradOp, ABS_GRAD) \ + cb(AddOp, ADD) \ + cb(FloorDivOp, FLOOR_DIV) \ + cb(MaxOp, MAX) \ + cb(MinOp, MIN) \ + cb(ModOp, MOD) \ + cb(SubOp, SUB) \ + cb(MulOp, MUL) \ + cb(TrueDivOp, TRUE_DIV) \ + cb(SigmoidGradOp, SIGMOID_GRAD) \ + cb(SwishGt0Op, SWITCH_GT0) \ + cb(TanhGradOp, TANH_GRAD) \ + cb(LtOp, LT) \ + cb(LeqOp, LEQ) \ + cb(EqOp, EQ) \ + cb(FuseAddReluOp, FUSE_ADD_RELU) \ + cb(LogSumExpOp, LOG_SUM_EXP) \ + cb(FuseAddTanhOp, FUSE_ADD_TANH) \ + cb(FastTanhGradOp, FAST_TANH_GRAD) \ + cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \ + cb(HswishGradOp, H_SWISH_GRAD) \ + cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) + +#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ + cb(CondLeqMovOp, COND_LEQ_MOV) \ + cb(FuseMulAdd3Op, FUSE_MUL_ADD3) +// clang-format on + +namespace mgb { +namespace jit { + +template +struct StandardOp; + +#define cb(mgb_op, fun) \ + template <> \ + struct StandardOp { \ + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \ + ValueRange operands) { \ + ValueBuilderHelper helper(builder, loc); \ + return helper.fun(operands); \ + } \ + } + +//! unary +cb(AbsOp, abs); +cb(NegOp, neg); +cb(ExpOp, exp); +cb(CosOp, cos); +cb(CeilOp, ceil); +cb(FloorOp, floor); +cb(LogOp, log); +cb(SinOp, sin); +cb(TanhOp, tanh); + +//! binary +cb(AddOp, add); +cb(MaxOp, max); +cb(MinOp, min); +cb(SubOp, sub); +cb(MulOp, mul); +cb(ModOp, mod); +cb(TrueDivOp, div); + +#undef cb + +/////////////////////////// unary op /////////////////////////// +//! max(x, 0) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.max(operands[0], helper.const_val(0.f)); + } +}; + +//! x * (27.f + x * x) / (27.f + 9.f * x * x); +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto square = helper.mul(operands[0], operands[0]); + return helper.div( + helper.mul(operands[0], + helper.add(helper.const_val(27.f), square)), + helper.add(helper.const_val(27.f), + helper.mul(helper.const_val(9.f), square))); + } +}; + +//! x * clip(x + 3, 0, 6) / 6 +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + + auto const_3 = helper.const_val(3.f); + auto const_0 = helper.const_val(0.f); + auto const_6 = helper.const_val(6.f); + auto tmp = helper.add(operands[0], const_3); + return helper.div( + helper.mul(operands[0], + helper.min(helper.max(tmp, const_0), const_6)), + const_6); + } +}; + +//! 
log(1 + p) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.log(helper.add(operands[0], helper.const_val(1.f))); + } +}; + +//! 1.f / (expf(-y) + 1.f)) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.div(helper.const_val(1.f), + helper.add(helper.exp(helper.neg(operands[0])), + helper.const_val(1.f))); + } +}; + +//! exp(x) - 1 +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.sub(helper.exp(operands[0]), helper.const_val(1.f)); + } +}; + +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select( + helper.gt(operands[0], helper.const_val(0.f)), + helper.floor(helper.add(operands[0], helper.const_val(0.5f))), + helper.ceil(helper.sub(operands[0], helper.const_val(0.5f)))); + } +}; + +/////////////////////////// binary op /////////////////////////// + +//! binary: x > 0 ? y : -y +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.gt(operands[0], helper.const_val(0.f)), + operands[1], helper.neg(operands[1])); + } +}; + +//! x * (1 - x) * y +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.mul( + helper.mul(operands[0], + helper.sub(helper.const_val(1.f), operands[0])), + operands[1]); + } +}; + +//! (x > 0) * y +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.gt(operands[0], helper.const_val(0.f)), + operands[1], helper.const_val(0.f)); + } +}; + +//! (1 - x * x) * y +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.mul(helper.sub(helper.const_val(1.0f), + helper.mul(operands[0], operands[0])), + operands[1]); + } +}; + +#define cb(op, fun) \ + template <> \ + struct StandardOp { \ + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \ + ValueRange operands) { \ + ValueBuilderHelper helper(builder, loc); \ + return helper.select(helper.fun(operands[0], operands[1]), \ + helper.const_val(1.f), \ + helper.const_val(0.f)); \ + } \ + } + +cb(LtOp, lt); +cb(LeqOp, le); +cb(EqOp, eq); +#undef cb + +//! (x + y) <= ctype(0) ? ctype(0) : (x + y) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto sum = helper.add(operands[0], operands[1]); + return helper.max(sum, helper.const_val(0.f)); + } +}; + +//! 
log(exp(x) + exp(y)) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.log( + helper.add(helper.exp(operands[0]), helper.exp(operands[1]))); + } +}; + +//! floor(x/y) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.floor(helper.div(operands[0], operands[1])); + } +}; + +//! tanh(x + y) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.tanh(helper.add(operands[0], operands[1])); + } +}; + +//! ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x * x) * y +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto x_pow2 = helper.mul(operands[0], operands[0]); + auto deno = helper.add(helper.const_val(3.f), x_pow2); + return helper.mul( + helper.div( + helper.add( + helper.add( + helper.div(helper.mul(helper.const_val( + -48.f), + x_pow2), + deno), + helper.const_val(27.f)), + x_pow2), + helper.mul(deno, helper.const_val(9.f))), + operands[1]); + } +}; + +//! 1.f / (expf(-(x+y)) + 1.f)) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.div(helper.const_val(1.f), + helper.add(helper.exp(helper.neg(helper.add( + operands[0], operands[1]))), + helper.const_val(1.f))); + } +}; + +//! x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select( + helper.lt(operands[0], helper.const_val(-3.f)), + helper.const_val(0.f), + helper.select( + helper.gt(operands[0], helper.const_val(3.f)), + operands[1], + helper.mul( + helper.div( + helper.add(helper.mul(helper.const_val( + 2.f), + operands[0]), + helper.const_val(3.f)), + helper.const_val(6.f)), + operands[1]))); + } +}; + +//! (x+y) * min(max(x + y + 3, 0), 6) * (1/6) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto sum = helper.add(operands[0], operands[1]); + + auto const_3 = helper.const_val(3.f); + auto const_0 = helper.const_val(0.f); + auto const_6 = helper.const_val(6.f); + auto tmp = helper.add(sum, const_3); + return helper.div( + helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)), + const_6); + } +}; + +/////////////////////////// ternary op /////////////////////////// +//! x <= y ? z : ctype(0) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select(helper.le(operands[0], operands[1]), operands[2], + helper.const_val(0.f)); + } +}; + +//! 
x * y + z +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.add(helper.mul(operands[0], operands[1]), operands[2]); + } +}; + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/interfaces.td b/src/jit/impl/mlir/ir/interfaces.td new file mode 100644 index 00000000..e5ca6778 --- /dev/null +++ b/src/jit/impl/mlir/ir/interfaces.td @@ -0,0 +1,33 @@ +/** + * \file src/jit/impl/mlir/ir/interfaces.td + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#ifndef MGB_MLIR_INTERFACES +#define MGB_MLIR_INTERFACES + +#ifndef OP_BASE +include "mlir/IR/OpBase.td" +#endif + +def GenericBuilderInterface : OpInterface<"GenericBuilder"> { + let methods = [ + StaticInterfaceMethod<"TODO", "Type", "getResultType", (ins "ArrayRef":$operands)>, + StaticInterfaceMethod<"TODO", "Operation*", "create", (ins + "OpBuilder*":$builder, + "Location":$loc, + "ArrayRef":$operands + )>, + ]; +} + +def ElemwiseOpInterface : OpInterface<"ElemwiseOp">; + +#endif diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp index 41c20cdc..684a03f2 100644 --- a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp @@ -16,11 +16,11 @@ #include "megbrain/common.h" #include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/passes.h" +#include "megbrain/jit/mlir/ir/utils.h" -#include "./common.h" +#include "./each_mode.h" #include -#include #include #include @@ -57,10 +57,40 @@ void lower_op_to_loops(Operation* op, ValueRange operands, rewriter.replaceOp(op, alloc); } -template +template +struct UnaryOpLowering : public ConversionPattern { + UnaryOpLowering(MLIRContext* ctx) + : ConversionPattern(Op::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); + lower_op_to_loops( + op, operands, rewriter, + [loc](OpBuilder& builder, ValueRange memref_operands, + ValueRange loop_ivs) { + typename Op::Adaptor binary_adaptor(memref_operands); + LoweredOp lower_op; + + auto loaded_lhs = builder.create( + loc, binary_adaptor.lhs(), loop_ivs); + + return lower_op(builder, loc, {loaded_lhs}); + }); + return success(); + } +}; + +#define cb(_op, _) \ + using _op##Lowering = UnaryOpLowering>; +MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) +#undef cb + +template struct BinaryOpLowering : public ConversionPattern { BinaryOpLowering(MLIRContext* ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + : ConversionPattern(Op::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite( Operation* op, ArrayRef operands, @@ -70,20 +100,61 @@ struct BinaryOpLowering : public ConversionPattern { op, operands, rewriter, [loc](OpBuilder& builder, ValueRange memref_operands, ValueRange loop_ivs) { - typename BinaryOp::Adaptor binary_adaptor(memref_operands); + typename Op::Adaptor binary_adaptor(memref_operands); + LoweredOp lower_op; auto loaded_lhs = builder.create( loc, 
binary_adaptor.lhs(), loop_ivs); auto loaded_rhs = builder.create( loc, binary_adaptor.rhs(), loop_ivs); - return builder.create(loc, loaded_lhs, - loaded_rhs); + return lower_op(builder, loc, {loaded_lhs, loaded_rhs}); }); return success(); } }; -using AddOpLowering = BinaryOpLowering; + +#define cb(_op, _) \ + using _op##Lowering = BinaryOpLowering>; +MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) +#undef cb + +template +struct TernaryOpLowering : public ConversionPattern { + TernaryOpLowering(MLIRContext* ctx) + : ConversionPattern(Op::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); + lower_op_to_loops( + op, operands, rewriter, + [loc](OpBuilder& builder, ValueRange memref_operands, + ValueRange loop_ivs) { + typename Op::Adaptor ternary_adaptor(memref_operands); + LoweredOp lower_op; + + auto loaded_x = builder.create( + loc, ternary_adaptor.x(), loop_ivs); + auto loaded_y = builder.create( + loc, ternary_adaptor.y(), loop_ivs); + auto loaded_z = builder.create( + loc, ternary_adaptor.z(), loop_ivs); + + return lower_op(builder, loc, + {loaded_x, loaded_y, loaded_z}); + }); + return success(); + } +}; + +#define cb(_op, _) \ + using _op##Lowering = \ + TernaryOpLowering>; +MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) +#undef cb + struct AssignOpLowering : public ConversionPattern { AssignOpLowering(MLIRContext* ctx) @@ -126,21 +197,18 @@ class MgbToAffineLoweringPass : public PassWrapper { public: void runOnFunction() override final { - auto function = getFunction(); - - // Verify that the given main has no inputs and results. - if (function.getType().getNumResults()) { - mgb_log_error("expected 'main' to have 0 results"); - return signalPassFailure(); - } - ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect(); OwningRewritePatternList patterns; - patterns.insert( - &getContext()); +#define cb(_op, _) _op##Lowering, + patterns.insert(&getContext()); +#undef cb if (failed(applyPartialConversion(getFunction(), target, patterns))) { signalPassFailure(); diff --git a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp index 0e0c4ed0..dddabef5 100644 --- a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp @@ -13,11 +13,11 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR +#include "./each_mode.h" #include "megbrain/common.h" #include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/passes.h" - -#include "../utils.h" +#include "megbrain/jit/mlir/ir/utils.h" #include #include @@ -62,10 +62,43 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) { return index; } -template +template +struct UnaryOpLowering : public ConversionPattern { + UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) + : ConversionPattern(Op::getOperationName(), 1, ctx), + m_launch_op{launch_op} {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); + + typename Op::Adaptor binary_adaptor(operands); + rewriter.setInsertionPointToEnd(&(m_launch_op->body().front())); + + auto index = get_tid(rewriter, loc); + auto loaded_lhs = + get_operand(rewriter, loc, binary_adaptor.lhs(), index); + + LoweredOp lower_op; + + rewriter.replaceOp(op, lower_op(rewriter, loc, {loaded_lhs})); + return success(); + } + +private: + 
gpu::LaunchOp* m_launch_op; +}; + +#define cb(_op, _) \ + using _op##Lowering = UnaryOpLowering>; +MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) +#undef cb + +template struct BinaryOpLowering : public ConversionPattern { BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx), + : ConversionPattern(Op::getOperationName(), 1, ctx), m_launch_op{launch_op} {} LogicalResult matchAndRewrite( @@ -73,7 +106,7 @@ struct BinaryOpLowering : public ConversionPattern { ConversionPatternRewriter& rewriter) const final { auto loc = op->getLoc(); - typename BinaryOp::Adaptor binary_adaptor(operands); + typename Op::Adaptor binary_adaptor(operands); rewriter.setInsertionPointToEnd(&(m_launch_op->body().front())); auto index = get_tid(rewriter, loc); @@ -82,10 +115,48 @@ struct BinaryOpLowering : public ConversionPattern { auto loaded_rhs = get_operand(rewriter, loc, binary_adaptor.rhs(), index); - auto binary_op = - rewriter.create(loc, loaded_lhs, loaded_rhs); + LoweredOp lower_op; + + rewriter.replaceOp(op, + lower_op(rewriter, loc, {loaded_lhs, loaded_rhs})); + return success(); + } + +private: + gpu::LaunchOp* m_launch_op; +}; + +#define cb(_op, _) \ + using _op##Lowering = BinaryOpLowering>; +MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) +#undef cb + +template +struct TernaryOpLowering : public ConversionPattern { + TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) + : ConversionPattern(Op::getOperationName(), 1, ctx), + m_launch_op{launch_op} {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op->getLoc(); - rewriter.replaceOp(op, binary_op.getResult()); + typename Op::Adaptor ternary_adaptor(operands); + rewriter.setInsertionPointToEnd(&(m_launch_op->body().front())); + + auto index = get_tid(rewriter, loc); + auto loaded_x = + get_operand(rewriter, loc, ternary_adaptor.x(), index); + auto loaded_y = + get_operand(rewriter, loc, ternary_adaptor.y(), index); + auto loaded_z = + get_operand(rewriter, loc, ternary_adaptor.z(), index); + + LoweredOp lower_op; + + rewriter.replaceOp( + op, lower_op(rewriter, loc, {loaded_x, loaded_y, loaded_z})); return success(); } @@ -93,7 +164,11 @@ private: gpu::LaunchOp* m_launch_op; }; -using AddOpLowering = BinaryOpLowering; +#define cb(_op, _) \ + using _op##Lowering = \ + TernaryOpLowering>; +MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) +#undef cb struct ReturnOpLowering : public ConversionPattern { ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op) @@ -194,6 +269,14 @@ public: patterns.insert( &getContext(), &launch_op); +#define cb(_op, _) _op##Lowering, + patterns.insert(&getContext(), &launch_op); +#undef cb + if (failed(applyPartialConversion(func_op, target, patterns))) { signalPassFailure(); } diff --git a/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp b/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp index 47785c04..6415ab43 100644 --- a/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp @@ -21,6 +21,7 @@ #include #include #include +#include using namespace mgb; using namespace jit; @@ -39,6 +40,7 @@ class AffineToLLVMLoweringPass : public PassWrapper traits = []> : - Op; +class ElemwiseBuilderImpl { + code ElemwiseBuilderImpl_create = [{ + static Operation* create(OpBuilder* builder, Location loc, ValueRange operands) { + OperationState state(loc, getOperationName()); + state.addOperands(operands); + state.addTypes(getResultType(operands)); + return 
builder->createOperation(state); + } + }]; +} + +class ElemwiseOp traits = [NoSideEffect]> : + Op, ElemwiseBuilderImpl; class GenericOp traits = []> : Op; -def AddOp : ElemwiseOp<"add", - [NoSideEffect, DeclareOpInterfaceMethods]> { - let summary = "element-wise addition operation"; - let description = [{ - The "add" operation performs element-wise addition between two tensors. - The shapes of the tensor operands are expected to match. - }]; +class ElemwiseUnaryOp traits = [NoSideEffect]> : + ElemwiseOp { + let arguments = (ins F32MemRef:$lhs); + let results = (outs F32MemRef); + + let builders = [OpBuilder< + "Builder* builder, OperationState& result, ValueRange operands", [{ + result.addOperands(operands); + result.addTypes(getResultType(operands)); + }]>, OpBuilder < + "OpBuilder& builder, OperationState& result, Value lhs", [{ + result.addOperands(lhs); + result.addTypes(getResultType({lhs})); + }] + >]; + + let extraClassDeclaration = [{ + static Type getResultType(ValueRange operands) { + return deduce_result_type(operands); + } + }] # ElemwiseBuilderImpl_create; +} + +def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>; +def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>; +def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>; +/* ACOS */ +/* ASIN */ +def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>; +def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>; +def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>; +def ExpM1Op : ElemwiseUnaryOp<"expm1", [NoSideEffect]>; +def FloorOp : ElemwiseUnaryOp<"floor", [NoSideEffect]>; +def LogOp : ElemwiseUnaryOp<"log", [NoSideEffect]>; +def Log1POp : ElemwiseUnaryOp<"log1p", [NoSideEffect]>; +def SigmoidOp: ElemwiseUnaryOp<"sigmoid", [NoSideEffect]>; +def SinOp : ElemwiseUnaryOp<"sin", [NoSideEffect]>; +def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>; +def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>; +def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>; +def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>; +/* ERF */ +/* ERFINV */ +/* ERFC */ +/* ERFCINV */ + +class ElemwiseBinaryOp traits = [NoSideEffect]> : + ElemwiseOp { let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs); let results = (outs F32MemRef); - // Specify a parser and printer method. 
- let parser = [{ return ::parseBinaryOp(parser, result); }]; - let printer = [{ return ::printBinaryOp(p, *this); }]; + let builders = [OpBuilder< + "Builder* builder, OperationState& result, ValueRange operands", [{ + result.addOperands(operands); + result.addTypes(getResultType(operands)); + }] + >, OpBuilder < + "OpBuilder& builder, OperationState& result, Value lhs, Value rhs", [{ + result.addOperands(lhs); + result.addOperands(rhs); + result.addTypes(getResultType({lhs, rhs})); + }] + >]; + + let extraClassDeclaration = [{ + static Type getResultType(ValueRange operands) { + return deduce_result_type(operands); + } + }] # ElemwiseBuilderImpl_create; +} + +def AbsGradOp : ElemwiseBinaryOp<"abs_grad", [NoSideEffect]>; +def AddOp : ElemwiseBinaryOp<"add", [Commutative, NoSideEffect]>; +def FloorDivOp : ElemwiseBinaryOp<"floor_div", [NoSideEffect]>; +def MaxOp : ElemwiseBinaryOp<"max", [Commutative, NoSideEffect]>; +def MinOp : ElemwiseBinaryOp<"min", [Commutative, NoSideEffect]>; +def ModOp : ElemwiseBinaryOp<"mod", [NoSideEffect]>; +def MulOp : ElemwiseBinaryOp<"mul", [Commutative, NoSideEffect]>; +def SubOp : ElemwiseBinaryOp<"sub", [NoSideEffect]>; +def SigmoidGradOp : ElemwiseBinaryOp<"sigmoid_grad", [NoSideEffect]>; +def SwishGt0Op : ElemwiseBinaryOp<"switch_gt0", [NoSideEffect]>; +def TanhGradOp : ElemwiseBinaryOp<"tanh_grad", [NoSideEffect]>; +def LtOp : ElemwiseBinaryOp<"lt", [NoSideEffect]>; +def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>; +def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>; +def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>; +def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>; +/* POW */ +def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>; +def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>; +def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>; +def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>; +def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>; +def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>; +/* ATAN2 */ + +class ElemwiseTernaryOp traits = [NoSideEffect]> : + ElemwiseOp { + + let arguments = (ins F32MemRef:$x, F32MemRef:$y, F32MemRef:$z); + let results = (outs F32MemRef); + + let builders = [OpBuilder< + "Builder* builder, OperationState& result, ValueRange operands", [{ + result.addOperands(operands); + result.addTypes(getResultType(operands)); + }] + >, OpBuilder < + "OpBuilder& builder, OperationState& result, Value x, Value y, Value z", [{ + result.addOperands(x); + result.addOperands(y); + result.addOperands(z); + result.addTypes(getResultType({x, y, z})); + }] + >]; - // Allow building an AddOp with from the two input operands. 
- let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, Value lhs, Value rhs"> - ]; + let extraClassDeclaration = [{ + static Type getResultType(ValueRange operands) { + return deduce_result_type(operands); + } + }] # ElemwiseBuilderImpl_create; } +def CondLeqMovOp: ElemwiseTernaryOp<"cond_leq_mov", [NoSideEffect]>; +def FuseMulAdd3Op: ElemwiseTernaryOp<"fuse_mul_add3", [NoSideEffect]>; + def ReturnOp : GenericOp<"return", [NoSideEffect, HasParent<"FuncOp">, Terminator]> { let summary = "return operation"; diff --git a/src/jit/impl/mlir/ir/shape_inference_interface.td b/src/jit/impl/mlir/ir/shape_inference_interface.td deleted file mode 100644 index 7ed48e5f..00000000 --- a/src/jit/impl/mlir/ir/shape_inference_interface.td +++ /dev/null @@ -1,30 +0,0 @@ -/** - * \file src/jit/impl/mlir/ir/shape_inference_interface.td - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#ifndef MGB_JIT_SHAPE_INFERENCE_INTERFACE -#define MGB_JIT_SHAPE_INFERENCE_INTERFACE - -include "mlir/IR/OpBase.td" - -def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { - let description = [{ - Interface to access a registered method to infer the return types for an - operation that can be used during type inference. - }]; - - let methods = [ - InterfaceMethod<"Infer and set the output shape for the current operation.", - "void", "infer_shapes"> - ]; -} - -#endif // MGB_SHAPE_INFERENCE_INTERFACE diff --git a/src/jit/impl/mlir/ir/shape_inference_pass.cpp b/src/jit/impl/mlir/ir/shape_inference_pass.cpp deleted file mode 100644 index c00d5c83..00000000 --- a/src/jit/impl/mlir/ir/shape_inference_pass.cpp +++ /dev/null @@ -1,100 +0,0 @@ -/** - * \file src/jit/impl/mlir/ir/shape_inference_pass.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "megbrain_build_config.h" -#if MGB_JIT && MGB_JIT_MLIR - -#include "megbrain/common.h" -#include "megbrain/jit/mlir/ir/dialect.h" -#include "megbrain/jit/mlir/ir/passes.h" -#include "megbrain/jit/mlir/ir/shape_inference_interface.h" - -#include -#include -#include - -using namespace mgb; -using namespace jit; - -#include "megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc" - -namespace { -class ShapeInferencePass - : public mlir::PassWrapper { -public: - void runOnFunction() override { - auto f = getFunction(); - - llvm::SmallPtrSet op_worklist; - f.walk([&](mlir::Operation* op) { - if (returns_dynamic_shape(op)) - op_worklist.insert(op); - }); - - // Iterate on the operations in the worklist until all operations have - // been inferred or no change happened (fix point). - while (!op_worklist.empty()) { - // Find the next operation ready for inference, that is an operation - // with all operands already resolved (non-generic). 
- auto nextop = llvm::find_if(op_worklist, all_operands_inferred); - if (nextop == op_worklist.end()) - break; - - Operation* op = *nextop; - op_worklist.erase(op); - - if (auto shapeOp = dyn_cast(op)) { - shapeOp.infer_shapes(); - } else { - mgb_log_error( - "unable to infer shape of operation without shape " - "inference interface"); - return signalPassFailure(); - } - } - - // If the operation worklist isn't empty, this indicates a failure. - if (!op_worklist.empty()) { - mgb_log_error( - "Shape inference failed, %zu operations couldn't be " - "inferred", - op_worklist.size()); - signalPassFailure(); - } - } - - //! A utility method that returns if the given operation has all of its - //! operands inferred. - static bool all_operands_inferred(Operation* op) { - return llvm::all_of(op->getOperandTypes(), [](Type operandType) { - return operandType.isa(); - }); - } - - //! A utility method that returns if the given operation has a dynamically - //! shaped result. - static bool returns_dynamic_shape(Operation* op) { - return llvm::any_of(op->getResultTypes(), [](Type resultType) { - return !resultType.isa(); - }); - } -}; - -} // namespace - -std::unique_ptr mgb::jit::create_shape_inference_pass() { - return std::make_unique(); -} - -#endif // MGB_JIT && MGB_JIT_MLIR - -// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/utils.cpp b/src/jit/impl/mlir/ir/utils.cpp new file mode 100644 index 00000000..53bf42fb --- /dev/null +++ b/src/jit/impl/mlir/ir/utils.cpp @@ -0,0 +1,111 @@ +/** + * \file src/jit/impl/mlir/ir/utils.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/common.h" +#include "megbrain/exception.h" +#include "megbrain/jit/mlir/ir/utils.h" +#include "megdnn/oprs/general.h" +#include "megdnn/basic_types.h" + +#include +#include +#include +#include +#include + +using namespace mgb; +using namespace jit; + +mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, + mlir::Location loc, + mlir::PatternRewriter& rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto* parent_block = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parent_block->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. 
+ auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parent_block->back()); + return alloc; +} + +mlir::Type jit::deduce_result_type(mlir::ValueRange operands) { + megdnn::TensorShapeArray srcs; + megdnn::TensorShape dst; + megdnn::DType dst_type; + for (auto operand : operands) { + auto type = operand.getType().dyn_cast_or_null(); + mgb_assert(type, "currently only support MemRefType"); + + srcs.push_back(mlir_type_to_layout(type)); + } + megdnn::Elemwise::deduce_shape(srcs, dst); + mlir::Builder builder(operands[0].getContext()); + return layout_to_mlir_type({dst, mlir_type_to_dtype(operands[0].getType())}, + builder); +} + +megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) { + megdnn::TensorLayout ret; + if (type.isa()) { + auto real_type = type.dyn_cast_or_null(); + mgb_assert(real_type); + ret.ndim = real_type.getRank(); + for (size_t i = 0; i < ret.ndim; i++) { + ret.shape[i] = real_type.getDimSize(i); + } + ret.dtype = mlir_type_to_dtype(real_type.getElementType()); + } + return ret; +} + +megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) { + mlir::Type element_type = type; + if (auto cast = type.dyn_cast_or_null()) { + element_type = cast.getElementType(); + } + switch (element_type.getKind()) { + case mlir::StandardTypes::F32: + return megdnn::dtype::Float32{}; + default: + mgb_throw(InternalError, + "Unsupport mlir type for MemRefType, got: %s\n", + mlir_type_to_string(type).c_str()); + } + return {}; +} + +mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, + mlir::Builder& builder) { + std::vector shape; + for (size_t i = 0; i < layout.ndim; i++) { + shape.push_back(layout[i]); + } + + switch (layout.dtype.enumv()) { + case megdnn::DTypeEnum::Float32: + return mlir::MemRefType::get(shape, builder.getF32Type()); + default: + mgb_throw(InternalError, "No supported dtype: %s", + layout.dtype.name()); + } +} + +#endif // MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp index 346a5355..3653609a 100644 --- a/src/jit/impl/mlir/mlir_gen.cpp +++ b/src/jit/impl/mlir/mlir_gen.cpp @@ -14,8 +14,10 @@ #if MGB_JIT && MGB_JIT_MLIR #include "./mlir_gen.h" -#include "./utils.h" +#include "./ir/each_mode.h" + #include "megbrain/jit/mlir/ir/dialect.h" +#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/opr/basic_arith.h" #include "megdnn/dtype.h" @@ -118,7 +120,7 @@ private: if (!return_op) { m_builder.create(m_builder.getUnknownLoc()); } - std::string op_content = to_string(func_op); + std::string op_content = mlir_type_to_string(func_op); func_op.setName( ssprintf("jit_mlir_%" PRIx64, XXHash{}.update(op_content.data(), op_content.size()) @@ -140,7 +142,8 @@ private: mgb_assert( mlir::succeeded(declare(opr->output(0)->name(), out))); } - }}.add(internal_graph.output()); + }} + .add(internal_graph.output()); m_builder.create(m_builder.getUnknownLoc(), get(internal_graph.output()), get(args.outputs[0].from)); @@ -150,11 +153,31 @@ private: mlir::Value gen_op(const opr::Elemwise& opr) { switch (opr.param().mode) { - case opr::Elemwise::Mode::ADD: - return m_builder.create(m_builder.getUnknownLoc(), - get(opr.input(0)), - get(opr.input(1))); - break; +#define cb(mlir_op, mgb_mode) \ + case opr::Elemwise::Mode::mgb_mode: \ + return m_builder.create(m_builder.getUnknownLoc(), \ + get(opr.input(0)), \ + get(opr.input(1))); \ + break; + MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) +#undef cb + +#define cb(mlir_op, 
mgb_mode) \ + case opr::Elemwise::Mode::mgb_mode: \ + return m_builder.create(m_builder.getUnknownLoc(), \ + get(opr.input(0))); \ + break; + MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) +#undef cb +#define cb(mlir_op, mgb_mode) \ + case opr::Elemwise::Mode::mgb_mode: \ + return m_builder.create( \ + m_builder.getUnknownLoc(), get(opr.input(0)), \ + get(opr.input(1)), get(opr.input(2))); \ + break; + MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) +#undef cb + default: return nullptr; } @@ -162,19 +185,7 @@ private: } mlir::Type get_type(const TensorLayout& layout) { - std::vector shape; - for (size_t i = 0; i < layout.ndim; i++) { - shape.push_back(layout[i]); - } - mgb_assert(layout.ndim != 0); - switch (layout.dtype.enumv()) { - case DTypeEnum::Float32: - return mlir::MemRefType::get(shape, m_builder.getF32Type()); - default: - mgb_throw(InternalError, "No supported dtype: %s", - layout.dtype.name()); - } - return mlir::UnrankedMemRefType::get(m_builder.getNoneType(), 0); + return layout_to_mlir_type(layout, m_builder); } mlir::Value get(const VarNode* var) { diff --git a/src/jit/impl/utils.cpp b/src/jit/impl/utils.cpp index 6edc1601..83e046ea 100644 --- a/src/jit/impl/utils.cpp +++ b/src/jit/impl/utils.cpp @@ -9,12 +9,12 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megbrain/jit/utils.h" #include "megbrain_build_config.h" #if MGB_JIT #include "megbrain/utils/debug.h" +#include "megbrain/jit/utils.h" #include diff --git a/src/jit/include/megbrain/jit/mlir/ir/dialect.h b/src/jit/include/megbrain/jit/mlir/ir/dialect.h index fd0b668e..f0ee5fe8 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/dialect.h +++ b/src/jit/include/megbrain/jit/mlir/ir/dialect.h @@ -15,13 +15,14 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR -#include +#include "megbrain/jit/mlir/ir/interfaces.h" +#include "megbrain/jit/mlir/ir/utils.h" + #include #include +#include #include -#include "megbrain/jit/mlir/ir/shape_inference_interface.h" - namespace mgb { namespace jit { diff --git a/src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h b/src/jit/include/megbrain/jit/mlir/ir/interfaces.h similarity index 62% rename from src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h rename to src/jit/include/megbrain/jit/mlir/ir/interfaces.h index ea5ba9e4..4803bc0d 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/shape_inference_interface.h +++ b/src/jit/include/megbrain/jit/mlir/ir/interfaces.h @@ -1,5 +1,5 @@ /** - * \file src/jit/impl/mlir/ir/shape_inference_interface.h + * \file src/jit/include/mlir/ir/interfaces.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -13,21 +13,16 @@ #pragma once #include "megbrain_build_config.h" -#if MGB_JIT && MGB_JIT_MLIR +#if MGB_JIT_MLIR -#include "mlir/IR/OpDefinition.h" - -namespace mgb { -namespace jit { +#include +#include +namespace mlir { /// Include the auto-generated declarations. 
-#include "megbrain/jit/mlir/ir/shape_inference_interface.h.inc" - -} // end namespace toy -} // end namespace mlir - - +#include "megbrain/jit/mlir/ir/interfaces.h.inc" +} -#endif // MGB_JIT && MGB_JIT_MLIR +#endif // MGB_JIT_MLIR // vim: syntax=cpp.doxygen diff --git a/src/jit/include/megbrain/jit/mlir/ir/passes.h b/src/jit/include/megbrain/jit/mlir/ir/passes.h index 9ca39908..1cb2d341 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/passes.h +++ b/src/jit/include/megbrain/jit/mlir/ir/passes.h @@ -13,19 +13,15 @@ #pragma once #include "megbrain_build_config.h" -#include -#include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR #include - +#include #include namespace mgb { namespace jit { -std::unique_ptr create_shape_inference_pass(); - /** * \brief Create a pass for lowering to operations in the `Affine` and `Std` * dialects, for a subset of the megbrain IR. diff --git a/src/jit/impl/mlir/utils.h b/src/jit/include/megbrain/jit/mlir/ir/utils.h similarity index 52% rename from src/jit/impl/mlir/utils.h rename to src/jit/include/megbrain/jit/mlir/ir/utils.h index 4d109af8..b9fb5fb5 100644 --- a/src/jit/impl/mlir/utils.h +++ b/src/jit/include/megbrain/jit/mlir/ir/utils.h @@ -1,13 +1,12 @@ /** - * \file src/jit/impl/mlir/utils.h + * \file src/jit/include/megbrain/mlir/ir/utils.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once @@ -15,28 +14,37 @@ #include "megbrain_build_config.h" #if MGB_JIT && MGB_JIT_MLIR -#include "megbrain/common.h" -#include "megbrain/exception.h" #include "megdnn/basic_types.h" #include "megdnn/dtype.h" -#include - -#include - -#include +#include +#include +#include namespace mgb { namespace jit { template -std::string to_string(T&& t) { +std::string mlir_type_to_string(T&& t) { std::string ret; llvm::raw_string_ostream stream(ret); t.print(stream); return ret; } +mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, + mlir::PatternRewriter& rewriter); + +mlir::Type deduce_result_type(mlir::ValueRange operands); + +/** + * \brief convert mlir type to TensorShape + */ +megdnn::TensorLayout mlir_type_to_layout(mlir::Type type); +megdnn::DType mlir_type_to_dtype(mlir::Type type); +mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout, + mlir::Builder& builder); + } // namespace jit } // namespace mgb diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp index 11cce30e..6d97a076 100644 --- a/src/jit/test/codegen.cpp +++ b/src/jit/test/codegen.cpp @@ -9,9 +9,11 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +#include #include "./helper.h" #include "megbrain/jit/executor_opr.h" +#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/test/helper.h" #include "megdnn/dtype.h" @@ -129,7 +131,7 @@ void run_mlir(CompNode cn) { HostTensorGenerator gen; auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn), - host_x2 = gen({23, 42}, cn); + host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn); auto a = opr::Host2DeviceCopy::make(*graph, host_x0), b = opr::Host2DeviceCopy::make(*graph, host_x1), @@ -137,7 +139,6 @@ void run_mlir(CompNode cn) { auto y = a + b + c; - VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; auto ig_gen = std::make_unique(y.node()->owner_opr()); @@ -157,6 +158,48 @@ void run_mlir(CompNode cn) { MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); } + +template +void run_mlir_mode(CompNode cn) { + set_backend(Backend::MLIR); + auto graph = ComputingGraph::make(); + float low = 0.f, high = 1.f; + if (tag::mode == opr::Elemwise::Mode::LOG) { + low = 0.1; + high = 4; + } + HostTensorGenerator gen(low, + high); + + SmallVector> hosts; + VarNodeArray input_vars; + for (int i = 0; i < arity; i++) { + hosts.push_back(gen({23, 42}, cn)); + input_vars.push_back( + opr::Host2DeviceCopy::make(*graph, hosts[i]).node()); + } + + auto y = opr::Elemwise::make(input_vars, tag::mode); + + auto ig_gen = + std::make_unique(y.node()->owner_opr()); + + for (auto i : get_rev_topo_order(y)) { + if (!i->template same_type()) { + ig_gen->add_opr(i); + } + } + + auto igraph = ig_gen->generate(); + auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); + + HostTensorND host_y, host_y_jit; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_jit, host_y_jit)}); + func->execute(); + + MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); +} #endif } // anonymous namespace @@ -191,6 +234,117 @@ TEST(TestJITMlirCodeGen, BasicGPU) { run_mlir(cn); } +///////////////////////// unary /////////////////////////////// +// clang-format off +#define FOREACH_UNARY_MODE(cb) \ + cb(RELU) \ + cb(ABS) \ + cb(NEGATE) \ + cb(CEIL) \ + cb(EXP) \ + cb(FLOOR) \ + cb(LOG) \ + cb(LOG1P) \ + cb(SIN) \ + cb(TANH) \ + cb(FAST_TANH) \ + cb(H_SWISH) \ + cb(SIGMOID) \ + cb(EXPM1) \ + cb(ROUND) +// clang-format on +template +class TestJITMlirUnaryElemwise : public ::testing::Test {}; + +#define def_tag(x) \ + struct x { \ + static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \ + }; +FOREACH_UNARY_MODE(def_tag) +#undef def_tag + +#define t(n) n, + using mlir_elemwise_unary_types = + ::testing::Types; +#undef t +TYPED_TEST_CASE(TestJITMlirUnaryElemwise, mlir_elemwise_unary_types); +TYPED_TEST(TestJITMlirUnaryElemwise, run) { + auto cn = CompNode::load("cpu0"); + run_mlir_mode(cn); +} + +///////////////////////// binary /////////////////////////////// +// clang-format off +#define FOREACH_BINARY_MODE(cb) \ + cb(ADD) \ + cb(FLOOR_DIV) \ + cb(MUL) \ + cb(MAX) \ + cb(MIN) \ + cb(MOD) \ + cb(SUB) \ + cb(TRUE_DIV) \ + cb(ABS_GRAD) \ + cb(SIGMOID_GRAD) \ + cb(SWITCH_GT0) \ + cb(TANH_GRAD) \ + cb(LT) \ + cb(LEQ) \ + cb(EQ) \ + cb(FUSE_ADD_RELU) \ + cb(LOG_SUM_EXP) \ + cb(FUSE_ADD_TANH) \ + cb(FAST_TANH_GRAD) \ + cb(FUSE_ADD_SIGMOID) \ + cb(H_SWISH_GRAD) \ + cb(FUSE_ADD_H_SWISH) +// clang-format on +template +class TestJITMlirBinaryElemwise : public ::testing::Test {}; + +#define def_tag(x) \ + struct x { \ + static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \ + }; +FOREACH_BINARY_MODE(def_tag) +#undef def_tag + +#define 
t(n) n, + using mlir_elemwise_binary_types = + ::testing::Types; +#undef t +TYPED_TEST_CASE(TestJITMlirBinaryElemwise, mlir_elemwise_binary_types); +TYPED_TEST(TestJITMlirBinaryElemwise, run) { + auto cn = CompNode::load("cpu0"); + run_mlir_mode(cn); +} + +///////////////////////// ternary /////////////////////////////// +// clang-format off +#define FOREACH_TERNARY_MODE(cb) \ + cb(COND_LEQ_MOV) \ + cb(FUSE_MUL_ADD3) \ +// clang-format on +template +class TestJITMlirTernaryElemwise : public ::testing::Test {}; + +#define def_tag(x) \ + struct x { \ + static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \ + }; +FOREACH_TERNARY_MODE(def_tag) +#undef def_tag + +#define t(n) n, + using mlir_elemwise_ternary_types = + ::testing::Types; +#undef t +TYPED_TEST_CASE(TestJITMlirTernaryElemwise, mlir_elemwise_ternary_types); +TYPED_TEST(TestJITMlirTernaryElemwise, run) { + auto cn = CompNode::load("cpu0"); + run_mlir_mode(cn); +} + #endif #endif // MGB_JIT diff --git a/src/opr/include/megbrain/opr/basic_arith_wrapper.h b/src/opr/include/megbrain/opr/basic_arith_wrapper.h index 1199176f..2397e83b 100644 --- a/src/opr/include/megbrain/opr/basic_arith_wrapper.h +++ b/src/opr/include/megbrain/opr/basic_arith_wrapper.h @@ -57,6 +57,7 @@ namespace opr { EL2(and_, AND) EL2(or_, OR) EL2(xor_, XOR) + EL2(mod, MOD) #undef EL2
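
The ValueBuilderHelper introduced in common.cpp synthesizes the ops missing from this LLVM snapshot's standard dialect (floor, log, min, max) out of primitives. A scalar mirror of the identities it emits, using only <cmath> calls that correspond one-to-one to the generated CeilFOp, Log10Op and CmpFOp+SelectOp sequences:

#include <cmath>

// floor(x) is emitted as -ceil(-x) until the toolchain gains a floor op
// (see the FIXME in ValueBuilderHelper::floor).
float ref_floor(float x) { return -std::ceil(-x); }

// log(x) is emitted as log10(x) / log10(e), with log10(e) = 0.4342944819032518f.
float ref_log(float x) { return std::log10(x) / 0.4342944819032518f; }

// min/max are emitted as CmpFOp(OLT/OGT) followed by SelectOp.
float ref_min(float a, float b) { return a < b ? a : b; }
float ref_max(float a, float b) { return a > b ? a : b; }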
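
each_mode.h expresses every fused elemwise mode as a StandardOp<> specialization built from those helper calls. The scalar formulas the specializations implement (taken from the //! comments in that file) are collected here as a reading aid; this block is not part of the patch:

#include <algorithm>
#include <cmath>

float relu(float x)      { return std::max(x, 0.f); }
float fast_tanh(float x) { return x * (27.f + x * x) / (27.f + 9.f * x * x); }
float hswish(float x)    { return x * std::min(std::max(x + 3.f, 0.f), 6.f) / 6.f; }
float sigmoid(float x)   { return 1.f / (std::exp(-x) + 1.f); }

// ROUND picks floor(x + 0.5) or ceil(x - 0.5) depending on the sign of x,
// i.e. round-half-away-from-zero.
float round_half_away(float x) {
    return x > 0.f ? std::floor(x + 0.5f) : std::ceil(x - 0.5f);
}

// The FUSE_ADD_* binaries apply the unary body to x + y.
float fuse_add_relu(float x, float y)   { return relu(x + y); }
float fuse_add_hswish(float x, float y) { return hswish(x + y); }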
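
With this structure, supporting a new mode only requires touching the TableGen op list, the FOREACH macro, and one StandardOp specialization; both lowering passes and mlir_gen.cpp pick it up through the cb expansions. A sketch of the per-mode piece for a hypothetical binary FooOp computing x * y + 1 (FooOp does not exist in the patch; the pattern is copied from the specializations in each_mode.h):

//! x * y + 1
template <>
struct StandardOp<jit::FooOp> {
    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
                           ValueRange operands) {
        ValueBuilderHelper helper(builder, loc);
        return helper.add(helper.mul(operands[0], operands[1]),
                          helper.const_val(1.f));
    }
};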
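
The macro plumbing itself is the classic X-macro pattern: each pass re-feeds its own cb into MLIR_MGB_FOREACH_ELEMWISE_MODE_* to stamp out one lowering alias (or one pattern registration, or one switch case in mlir_gen.cpp) per mode. A self-contained toy of the same technique, independent of MLIR:

#include <iostream>

#define FOREACH_MODE(cb) cb(Add) cb(Mul) cb(Sub)

// Each expansion of cb defines one lowering stub, mirroring how
// cb(AddOp, ADD) defines `using AddOpLowering = BinaryOpLowering<...>;`.
#define cb(name) void lower_##name() { std::cout << #name << '\n'; }
FOREACH_MODE(cb)
#undef cb

int main() {
    lower_Add();
    lower_Mul();
    lower_Sub();
}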
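
On the TableGen side, ElemwiseBuilderImpl mixes a uniform static create() into every elemwise op next to getResultType(), so generic code can construct any mode without naming its C++ class up front. How the generated members compose (a usage sketch against the project headers, not standalone code):

// TableGen generates, for e.g. jit::AddOp:
//   static Type getResultType(ValueRange operands);            // broadcast shape
//   static Operation* create(OpBuilder*, Location, ValueRange);
mlir::Operation* make_elemwise(mlir::OpBuilder& builder, mlir::Location loc,
                               mlir::ValueRange operands) {
    return jit::AddOp::create(&builder, loc, operands);
}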
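
deduce_result_type in utils.cpp converts each MemRef operand to a TensorLayout and defers the shape arithmetic to megdnn::Elemwise::deduce_shape. Assuming the usual trailing-dimension broadcast rule (extents equal, or one side equal to 1), a minimal stand-in for that deduction; deduce_shape itself is the megdnn call, this helper is hypothetical:

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<size_t> broadcast_shape(std::vector<size_t> a,
                                    std::vector<size_t> b) {
    if (a.size() < b.size()) std::swap(a, b);  // a is now the longer shape
    std::vector<size_t> out = a;
    for (size_t i = 0; i < b.size(); ++i) {    // align from trailing dims
        size_t& x = out[out.size() - 1 - i];
        size_t  y = b[b.size() - 1 - i];
        if (x == y || y == 1) continue;        // compatible as-is
        if (x == 1) { x = y; continue; }       // broadcast the 1 side
        throw std::runtime_error("shapes do not broadcast");
    }
    return out;
}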
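
The GPU path lowers each op into the body of a single gpu::LaunchOp: operands are loaded at the flat index produced by get_tid, and the CPU executable now passes nr_elements as the trailing kernel argument (the push_back fix in executable_cpu.cpp, replacing an out-of-bounds indexed write). get_tid's body is not shown in this hunk; assuming the conventional blockIdx.x * blockDim.x + threadIdx.x mapping, the generated kernel behaves like this host-side mirror, one iteration per GPU thread:

// Host-side mirror of the generated 1-D elementwise kernel (FUSE_ADD_RELU).
void fuse_add_relu_kernel(const float* x, const float* y, float* dst,
                          size_t nr_elements) {
    for (size_t tid = 0; tid < nr_elements; ++tid) {
        float sum = x[tid] + y[tid];
        dst[tid] = sum > 0.f ? sum : 0.f;  // max(x + y, 0)
    }
}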