/**
 * \file src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include "./common.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/Sequence.h>

// NOTE(review): the reviewed copy of this file had lost every `<...>` span
// (system include targets and template argument lists). They are restored
// here following the MLIR Toy tutorial lowering this file mirrors; please
// confirm the dialect/op names (MgbDialect, jit::AddOp, AddFOp, ...) against
// megbrain/jit/mlir/ir/dialect.h.

using namespace mgb;
using namespace jit;

namespace {

/*!
 * \brief callback producing the value stored for one loop iteration
 *
 * Receives the builder positioned inside the innermost loop, the operands
 * handed to lower_op_to_loops() and the induction variables of the loop nest.
 */
using LoopIterationFn = function_ref<Value(
        OpBuilder& rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;

/*!
 * \brief replace \p op by an affine loop nest writing into a new buffer
 *
 * The buffer is obtained from jit::insert_alloc_and_dealloc() using the type
 * of the first result of \p op, which must be a MemRefType. One loop is
 * generated per dimension, running from 0 to the dimension extent with
 * step 1; in the innermost body the value returned by \p process_iteration
 * is stored to the buffer at the current induction variables.
 *
 * \param op operation being lowered; replaced by the buffer
 * \param operands operands forwarded to \p process_iteration
 * \param rewriter rewriter used to create the new IR
 * \param process_iteration computes the value to store for one index
 */
void lower_op_to_loops(Operation* op, ValueRange operands,
                       PatternRewriter& rewriter,
                       LoopIterationFn process_iteration) {
    auto memref_type = (*op->result_type_begin()).cast<MemRefType>();
    auto loc = op->getLoc();

    auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);

    SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
    SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
    buildAffineLoopNest(
            rewriter, loc, lower_bounds, memref_type.getShape(), steps,
            [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
                Value value_to_store =
                        process_iteration(nested_builder, operands, ivs);
                nested_builder.create<AffineStoreOp>(loc, value_to_store, alloc,
                                                     ivs);
            });

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
}

/*!
 * \brief lower an elementwise binary op to loads plus a scalar op
 *
 * \tparam BinaryOp the op being matched; its Adaptor must expose lhs()/rhs()
 * \tparam LoweredBinaryOp the op created on the two loaded elements
 */
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
    explicit BinaryOpLowering(MLIRContext* ctx)
            : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        lower_op_to_loops(
                op, operands, rewriter,
                [loc](OpBuilder& builder, ValueRange memref_operands,
                      ValueRange loop_ivs) {
                    typename BinaryOp::Adaptor binary_adaptor(memref_operands);

                    // load both inputs at the current loop index
                    auto loaded_lhs = builder.create<AffineLoadOp>(
                            loc, binary_adaptor.lhs(), loop_ivs);
                    auto loaded_rhs = builder.create<AffineLoadOp>(
                            loc, binary_adaptor.rhs(), loop_ivs);

                    return builder.create<LoweredBinaryOp>(loc, loaded_lhs,
                                                           loaded_rhs);
                });
        return success();
    }
};
using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;

/*!
 * \brief lower jit::AssignOp to an elementwise copy loop
 *
 * For every index of the first operand's memref shape, the element loaded
 * from the adaptor's lhs() is stored into the adaptor's rhs(); the original
 * op is then erased.
 */
struct AssignOpLowering : public ConversionPattern {
    explicit AssignOpLowering(MLIRContext* ctx)
            : ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        auto memref_type = operands[0].getType().cast<MemRefType>();
        AssignOpAdaptor assign_adaptor(operands);

        SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
        SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
        buildAffineLoopNest(
                rewriter, loc, lower_bounds, memref_type.getShape(), steps,
                [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
                    auto loaded_lhs = nested_builder.create<AffineLoadOp>(
                            loc, assign_adaptor.lhs(), ivs);
                    nested_builder.create<AffineStoreOp>(
                            loc, loaded_lhs, assign_adaptor.rhs(), ivs);
                });

        rewriter.eraseOp(op);
        return success();
    }
};

/*!
 * \brief replace jit::ReturnOp by an operand-less mlir::ReturnOp
 */
struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
    using OpRewritePattern<jit::ReturnOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(jit::ReturnOp op,
                                  PatternRewriter& rewriter) const final {
        rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
        return success();
    }
};

/*!
 * \brief function pass running the lowering patterns defined in this file
 *
 * The pass fails if the function has results or if the partial conversion
 * fails.
 */
class MgbToAffineLoweringPass
        : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
    void runOnFunction() override final {
        auto function = getFunction();

        // Verify that the given function has no results (its inputs are not
        // checked here).
        if (function.getType().getNumResults()) {
            mgb_log_error("expected 'main' to have 0 results");
            signalPassFailure();
            return;
        }

        ConversionTarget target(getContext());
        target.addLegalDialect<AffineDialect, StandardOpsDialect>();
        target.addIllegalDialect<MgbDialect>();

        OwningRewritePatternList patterns;
        patterns.insert<AddOpLowering, ReturnOpLowering, AssignOpLowering>(
                &getContext());

        // Partial conversion: ops that are neither legal nor illegal for the
        // target are left untouched.
        if (failed(applyPartialConversion(function, target, patterns))) {
            signalPassFailure();
        }
    }
};

}  // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() {
    return std::make_unique<MgbToAffineLoweringPass>();
}

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen