From 5d0f8da46a4975c3c1a290299b839411faee732a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 19 Nov 2020 23:34:06 +0800
Subject: [PATCH] feat(mgb/jit): add Dimshuffle and lowering passes in jit mlir
 backend

GitOrigin-RevId: ce6f4ea42a876fafbb7ca67f30d5c0fa96d28096
---
 src/jit/impl/mlir/compiler.h                  |  2 +-
 src/jit/impl/mlir/ir/lower_to_affine_pass.cpp | 43 ++++++++++++++++++---
 src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp    | 47 +++++++++++++++++++---
 src/jit/impl/mlir/mlir_gen.cpp                | 39 ++++++++++++++----
 src/jit/test/codegen.cpp                      | 46 ++++++++++++++++++++
 5 files changed, 166 insertions(+), 11 deletions(-)

diff --git a/src/jit/impl/mlir/compiler.h b/src/jit/impl/mlir/compiler.h
index 594a04ec..246b748e 100644
--- a/src/jit/impl/mlir/compiler.h
+++ b/src/jit/impl/mlir/compiler.h
@@ -34,7 +34,7 @@ public:
     Property property() const override {
         using F = Property::Flag;
         return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
-                        JITFeatureBits::NONE, 64};
+                        JITFeatureBits::DIMSHUFFLE, 64};
     }
 
     size_t get_nr_workspace_outputs(JITExecutor* opr) const override;
diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
index 595e2fef..c99867ba 100644
--- a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
+++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
@@ -62,6 +62,7 @@ struct ElemwiseLowering : public ConversionPattern {
     ElemwiseLowering(MLIRContext* ctx)
             : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
                                 ctx) {}
+
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -89,6 +90,7 @@ struct TypeCvtLowering : public ConversionPattern {
     TypeCvtLowering(MLIRContext* ctx)
             : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
                                 ctx) {}
+
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -105,6 +107,41 @@ struct TypeCvtLowering : public ConversionPattern {
     }
 };
 
+struct DimshuffleLowering : public ConversionPattern {
+    DimshuffleLowering(MLIRContext* ctx)
+            : ConversionPattern(mgb::dialect::Dimshuffle::getOperationName(), 1,
+                                ctx) {}
+
+    static mlir::AffineMap get_affinemap_from_pattern(
+            const std::vector<int32_t>& pattern, mlir::MLIRContext* ctx) {
+        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
+        std::vector<mlir::AffineExpr> exprs(ndim);
+        for (size_t i = 0; i < pattern.size(); i++) {
+            int32_t j = pattern[i];
+            if (j >= 0) {
+                exprs[j] = mlir::getAffineDimExpr(i, ctx);
+            }
+        }
+        return mlir::AffineMap::get(pattern.size(), 0, exprs, ctx);
+    }
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
+        auto map = get_affinemap_from_pattern(pattern, op->getContext());
+        lower_op_to_loops(
+                op, operands, rewriter,
+                [loc, op, &map](OpBuilder& builder, ValueRange memref_operands,
+                                ValueRange loop_ivs) {
+                    return builder.create<AffineLoadOp>(loc, memref_operands[0],
+                                                        map, loop_ivs);
+                });
+        return success();
+    }
+};
+
 struct AssignOpLowering : public ConversionPattern {
     AssignOpLowering(MLIRContext* ctx)
             : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
@@ -172,9 +209,9 @@ public:
         target.addIllegalDialect<MgbDialect>();
 
         OwningRewritePatternList patterns;
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        ConstantScalarOpLowering, AssignOpLowering>(
-                &getContext());
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, ConstantScalarOpLowering,
+                        AssignOpLowering>(&getContext());
 
         if (failed(applyPartialConversion(getFunction(), target,
                                           std::move(patterns)))) {
diff --git a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
index f0580cad..3ab9166e 100644
--- a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
+++ b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
@@ -152,6 +152,47 @@ private:
     gpu::LaunchOp& m_launch_op;
 };
 
+struct DimshuffleLowering : public ConversionPattern {
+    DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
+                                ctx),
+              m_launch_op{launch_op} {}
+
+    static std::vector<mlir::Value> get_index_from_pattern(
+            const std::vector<int32_t>& pattern,
+            const std::vector<mlir::Value>& index) {
+        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
+        std::vector<mlir::Value> res(ndim);
+        for (size_t i = 0; i < pattern.size(); i++) {
+            int32_t j = pattern[i];
+            if (j >= 0) {
+                res[j] = index[i];
+            }
+        }
+        return res;
+    }
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
+
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
+        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
+        auto shuffled_index = get_index_from_pattern(pattern, index);
+
+        rewriter.replaceOp(op, get_operand(rewriter, loc, operands[0],
+                                           shuffled_index));
+        return success();
+    }
+
+private:
+    gpu::LaunchOp& m_launch_op;
+};
+
 struct ReturnOpLowering : public ConversionPattern {
     ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
@@ -275,9 +316,9 @@ public:
         target.addLegalDialect<gpu::GPUDialect, StandardOpsDialect>();
         target.addIllegalDialect<MgbDialect>();
 
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        ConstantScalarOpLowering, AssignOpLowering>(
-                &getContext(), launch_op);
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, ConstantScalarOpLowering,
+                        AssignOpLowering>(&getContext(), launch_op);
 
         if (failed(applyPartialConversion(func_op, target,
                                           std::move(patterns)))) {
diff --git a/src/jit/impl/mlir/mlir_gen.cpp b/src/jit/impl/mlir/mlir_gen.cpp
index 399cabe2..93b1048b 100644
--- a/src/jit/impl/mlir/mlir_gen.cpp
+++ b/src/jit/impl/mlir/mlir_gen.cpp
@@ -20,6 +20,7 @@
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/utils.h"
 #include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megdnn/dtype.h"
 
 #include
@@ -160,6 +161,10 @@ private:
             mgb_assert(
                     mlir::succeeded(declare(opr->output(0)->name(), out)));
             return;
+        } else if (opr->same_type<opr::Dimshuffle>()) {
+            auto&& out = gen_dimshuffle(opr->cast_final<opr::Dimshuffle>());
+            mgb_assert(
+                    mlir::succeeded(declare(opr->output(0)->name(), out)));
         } else if (opr->same_type<opr::TypeCvt>()) {
             auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
             mgb_assert(
@@ -186,18 +191,44 @@
     }
 
     mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
-        auto shape = get(opr.input(0))
+        auto itype = get(opr.input(0))
                              .getType()
-                             .dyn_cast_or_null<mlir::MemRefType>()
-                             .getShape();
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "currently only support MemRefType");
         auto res_type = mlir::MemRefType::get(
-                shape,
+                itype.getShape(),
                 megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
         return m_builder.create<dialect::TypeCvt>(
                 m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
                 opr.input(0)->dtype(), opr.param());
     }
 
+    mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) {
+        auto itype = get(opr.input(0))
+                             .getType()
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "the input type of Dimshuffle must be MemRefType");
+        auto ishape = itype.getShape();
+        auto param = opr.param();
+
+        std::vector<int32_t> pattern;
+        std::vector<int64_t> oshape;
+        for (size_t i = 0; i < param.pattern_len; i++) {
+            int32_t j = param.pattern[i];
+            pattern.push_back(j);
+            if (j < 0) {
+                oshape.push_back(1);
+            } else {
+                oshape.push_back(ishape[j]);
+            }
+        }
+        auto res_type = mlir::MemRefType::get(oshape, itype.getElementType());
+
+        return m_builder.create<dialect::Dimshuffle>(
+                m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
+                pattern);
+    }
+
     mlir::Type get_type(const TensorLayout& layout) {
         return layout_to_mlir_type(layout, m_builder);
     }
diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp
index 9b6f4473..1e07c87b 100644
--- a/src/jit/test/codegen.cpp
+++ b/src/jit/test/codegen.cpp
@@ -15,6 +15,7 @@
 #include "megbrain/jit/executor_opr.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/test/helper.h"
 #include "megdnn/dtype.h"
 
@@ -539,6 +540,51 @@ add_typecvt_gtest(Uint8, Float32);
 
 #undef add_typecvt_gtest
 
+/* ===================== TestJITMlirDimshuffle ===================== */
+
+void run_dimshuffle(CompNode cn, TensorShape ishape,
+                    const std::vector<int>& pattern) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+
+    auto host_x = gen(ishape, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    auto y = opr::Dimshuffle::make(x, pattern);
+
+    auto ig_gen = std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->template same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+
+void run_dimshuffle_cases(CompNode cn) {
+    run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1});
+    run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2});
+}
+
+TEST(TestJITMlirDimshuffle, Basic) {
+    run_dimshuffle_cases(CompNode::load("cpu0"));
+}
+
+TEST(TestJITMlirDimshuffle, BasicGPU) {
+    REQUIRE_GPU(1);
+    run_dimshuffle_cases(CompNode::load("gpu0"));
+}
+
 #endif // MGB_JIT_MLIR
 
 #endif // MGB_JIT
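
Note on the pattern semantics (an editorial sketch, not part of the patch):
in both new lowerings, pattern[i] = j means output axis i reads input axis j,
and -1 inserts a broadcast axis of extent 1. get_affinemap_from_pattern
encodes that inversion as an affine map (for the test pattern {1, -1, 0, 2}
it builds (d0, d1, d2, d3) -> (d2, d0, d3)), and the GPU lowering applies the
same inversion to per-thread index values in get_index_from_pattern. The
standalone C++ below, with illustrative names only, replays the shape and
index computation on the second test case from codegen.cpp.

    // Standalone sketch (illustrative names, not MegEngine code) of the
    // Dimshuffle index transformation used by both lowering passes above.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Output shape: one extent per pattern entry; -1 contributes a
    // broadcast axis of extent 1 (cf. gen_dimshuffle in mlir_gen.cpp).
    std::vector<int64_t> shuffled_shape(const std::vector<int32_t>& pattern,
                                        const std::vector<int64_t>& ishape) {
        std::vector<int64_t> oshape;
        for (int32_t j : pattern)
            oshape.push_back(j < 0 ? 1 : ishape[j]);
        return oshape;
    }

    // Input coordinate for a given output coordinate: invert the pattern,
    // as the lowerings do when the output iteration space indexes the
    // input memref.
    std::vector<int64_t> input_coord(const std::vector<int32_t>& pattern,
                                     const std::vector<int64_t>& out_coord) {
        size_t ndim = 0;
        for (int32_t j : pattern)
            if (j >= 0 && size_t(j) + 1 > ndim)
                ndim = j + 1;
        std::vector<int64_t> in_coord(ndim);
        for (size_t i = 0; i < pattern.size(); i++)
            if (pattern[i] >= 0)
                in_coord[pattern[i]] = out_coord[i];
        return in_coord;
    }

    int main() {
        // Second test case above: ishape {3, 4, 5}, pattern {1, -1, 0, 2}.
        std::vector<int32_t> pattern{1, -1, 0, 2};
        std::vector<int64_t> ishape{3, 4, 5};
        for (int64_t d : shuffled_shape(pattern, ishape))     // 4 1 3 5
            std::printf("%lld ", (long long)d);
        std::printf("\n");
        for (int64_t d : input_coord(pattern, {2, 0, 1, 3}))  // 1 2 3
            std::printf("%lld ", (long long)d);
        std::printf("\n");
        return 0;
    }

Running it prints "4 1 3 5" and "1 2 3": output coordinate (2, 0, 1, 3) reads
input element (1, 2, 3), which is exactly the access performed by the
AffineLoadOp built with the map above.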