
lower_to_affine_pass.cpp 5.8 kB

/**
 * \file src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include "./common.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/Sequence.h>

using namespace mgb;
using namespace jit;

namespace {

//! Callback that emits the body of one loop iteration and returns the value
//! to be stored at the current induction variables.
using LoopIterationFn = function_ref<Value(
        OpBuilder& rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
/**
 * Lower an element-wise op to an affine loop nest that computes each output
 * element via \p process_iteration and stores it into a freshly allocated
 * buffer, which then replaces the original op's result.
 */
void lower_op_to_loops(Operation* op, ValueRange operands,
                       PatternRewriter& rewriter,
                       LoopIterationFn process_iteration) {
    auto memref_type = (*op->result_type_begin()).cast<MemRefType>();
    auto loc = op->getLoc();
    auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);

    // Iterate over the full output shape: lower bound 0 and step 1 in every
    // dimension; the upper bounds are the memref's static shape.
    SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
    SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
    buildAffineLoopNest(
            rewriter, loc, lower_bounds, memref_type.getShape(), steps,
            [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
                Value value_to_store =
                        process_iteration(nested_builder, operands, ivs);
                nested_builder.create<AffineStoreOp>(loc, value_to_store, alloc,
                                                     ivs);
            });

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
}
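
// For a rank-2 result of shape M x N, the nest built above is conceptually
// the following IR (a sketch of the lowered form, not emitted verbatim):
//   affine.for %i = 0 to M {
//     affine.for %j = 0 to N {
//       %v = ...  // process_iteration(builder, operands, {%i, %j})
//       affine.store %v, %alloc[%i, %j] : memref<MxNxf32>
//     }
//   }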
/**
 * Lower a binary MgbDialect op to affine loads of both operands, the
 * corresponding standard-dialect op, and a store of the result.
 */
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
    BinaryOpLowering(MLIRContext* ctx)
            : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        lower_op_to_loops(
                op, operands, rewriter,
                [loc](OpBuilder& builder, ValueRange memref_operands,
                      ValueRange loop_ivs) {
                    // The adaptor provides named access to the remapped
                    // memref operands of the original op.
                    typename BinaryOp::Adaptor binary_adaptor(memref_operands);
                    auto loaded_lhs = builder.create<AffineLoadOp>(
                            loc, binary_adaptor.lhs(), loop_ivs);
                    auto loaded_rhs = builder.create<AffineLoadOp>(
                            loc, binary_adaptor.rhs(), loop_ivs);
                    return builder.create<LoweredBinaryOp>(loc, loaded_lhs,
                                                           loaded_rhs);
                });
        return success();
    }
};
using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;
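
// Further element-wise binary ops could be lowered by instantiating the same
// template with another op pair, e.g. (hypothetical; assumes the dialect
// defines a jit::MulOp, paired with the standard float multiply MulFOp):
//   using MulOpLowering = BinaryOpLowering<jit::MulOp, MulFOp>;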
//! Lower jit::AssignOp to an affine loop nest copying lhs into rhs.
struct AssignOpLowering : public ConversionPattern {
    AssignOpLowering(MLIRContext* ctx)
            : ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        auto memref_type = operands[0].getType().cast<MemRefType>();
        AssignOpAdaptor assign_adaptor(operands);

        SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
        SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
        buildAffineLoopNest(
                rewriter, loc, lower_bounds, memref_type.getShape(), steps,
                [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
                    auto loaded_lhs = nested_builder.create<AffineLoadOp>(
                            loc, assign_adaptor.lhs(), ivs);
                    nested_builder.create<AffineStoreOp>(
                            loc, loaded_lhs, assign_adaptor.rhs(), ivs);
                });

        // The assign produces no result; the stores above are its effect.
        rewriter.eraseOp(op);
        return success();
    }
};
//! Lower jit::ReturnOp to a std return with no operands; the surrounding
//! function is verified below to produce no results.
struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
    using OpRewritePattern<jit::ReturnOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(jit::ReturnOp op,
                                  PatternRewriter& rewriter) const final {
        rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
        return success();
    }
};
class MgbToAffineLoweringPass
        : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
    void runOnFunction() override final {
        auto function = getFunction();

        // Verify that the given 'main' returns no results.
        if (function.getType().getNumResults()) {
            mgb_log_error("expected 'main' to have 0 results");
            return signalPassFailure();
        }

        // Affine and standard ops are legal lowering targets; any MgbDialect
        // op remaining after conversion is an error.
        ConversionTarget target(getContext());
        target.addLegalDialect<AffineDialect, StandardOpsDialect>();
        target.addIllegalDialect<MgbDialect>();

        OwningRewritePatternList patterns;
        patterns.insert<AddOpLowering, ReturnOpLowering, AssignOpLowering>(
                &getContext());

        if (failed(applyPartialConversion(getFunction(), target, patterns))) {
            signalPassFailure();
        }
    }
};
}  // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() {
    return std::make_unique<MgbToAffineLoweringPass>();
}
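
// Typical usage (a sketch, not part of this file; `context` and `module` are
// assumed to be a registered MLIRContext and the mlir::ModuleOp holding the
// JIT-generated function):
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(mgb::jit::create_lower_to_affine_pass());
//   if (mlir::failed(pm.run(module))) {
//       mgb_log_error("lowering MgbDialect to affine failed");
//   }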
#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and the driver is installed. If you would like to try deep-learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.