
lower_to_gpu_pass.cpp

/**
 * \file src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include "../utils.h"

#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>
using namespace mgb;
using namespace jit;

namespace {
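//! Helper: load one scalar element from a memref operand at \p index;
//! operands that are already scalar values are returned unchanged.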
mlir::Value get_operand(ConversionPatternRewriter& rewriter,
                        const mlir::Location& loc, const mlir::Value& val,
                        const mlir::Value& index) {
    if (val.getType().isa<mlir::MemRefType>()) {
        return rewriter.create<LoadOp>(loc, val, index);
    } else {
        return val;
    }
}
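//! Compute the global 1D thread id of the current GPU thread, the
//! equivalent of CUDA's `blockIdx.x * blockDim.x + threadIdx.x`.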
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
    auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
    auto block_idx = rewriter.create<gpu::BlockIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
    auto group_size = rewriter.create<gpu::BlockDimOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));

    Value index = rewriter.create<AddIOp>(
            loc, thread_idx,
            rewriter.create<MulIOp>(loc, block_idx, group_size));
    return index;
}
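//! Lowers an elementwise binary op from the mgb dialect: both operands are
//! loaded at the current thread id inside the gpu.launch body and combined
//! with the corresponding lowered arithmetic op (e.g. jit::AddOp -> AddFOp).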
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
            : ConversionPattern(BinaryOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        typename BinaryOp::Adaptor binary_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));

        auto index = get_tid(rewriter, loc);
        auto loaded_lhs =
                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
        auto loaded_rhs =
                get_operand(rewriter, loc, binary_adaptor.rhs(), index);

        auto binary_op =
                rewriter.create<LoweredBinaryOp>(loc, loaded_lhs, loaded_rhs);

        rewriter.replaceOp(op, binary_op.getResult());
        return success();
    }

private:
    gpu::LaunchOp* m_launch_op;
};
using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;
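//! Lowers jit::ReturnOp to std.return and prepends a bounds guard to the
//! gpu.launch body: threads whose global id is >= nr_tid (passed as the
//! last function argument) terminate without touching any buffer.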
struct ReturnOpLowering : public ConversionPattern {
    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
            : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value>,
            ConversionPatternRewriter& rewriter) const final {
        rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);

        auto loc = op->getLoc();
        //! remove the first gpu.terminator
        m_launch_op->body().front().front().erase();

        //! insert `if (tid >= nr_tid) { return; }` at the beginning of the block
        rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
        Block* cond_block = rewriter.getInsertionBlock();
        Block::iterator op_position = rewriter.getInsertionPoint();
        Block* remaining_ops_block =
                rewriter.splitBlock(cond_block, op_position);

        rewriter.setInsertionPointToEnd(cond_block);
        auto index = get_tid(rewriter, loc);
        auto comparison = rewriter.create<mlir::CmpIOp>(
                loc, CmpIPredicate::sge, index,
                m_launch_op->getParentOfType<mlir::FuncOp>()
                        .getArguments()
                        .back());

        Block* then_block =
                rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
        rewriter.setInsertionPointToEnd(then_block);
        rewriter.create<gpu::TerminatorOp>(loc);

        rewriter.setInsertionPointToEnd(cond_block);
        rewriter.create<mlir::CondBranchOp>(
                loc, comparison, then_block, ArrayRef<Value>(),
                remaining_ops_block, ArrayRef<Value>());

        rewriter.setInsertionPointToEnd(remaining_ops_block);
        rewriter.create<gpu::TerminatorOp>(loc);

        return success();
    }

private:
    gpu::LaunchOp* m_launch_op;
};
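//! Lowers jit::AssignOp: stores the lhs value computed for the current
//! thread id into the rhs (destination) memref.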
struct AssignOpLowering : public ConversionPattern {
    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
            : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        AssignOpAdaptor assign_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));

        auto index = get_tid(rewriter, loc);
        auto loaded_lhs =
                get_operand(rewriter, loc, assign_adaptor.lhs(), index);
        rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);

        rewriter.eraseOp(op);
        return success();
    }

private:
    gpu::LaunchOp* m_launch_op;
};
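//! The pass itself: creates a trivial 1x1x1 gpu.launch region at the top of
//! the function, then runs the conversion patterns above, which rebuild the
//! mgb dialect ops as standard/GPU dialect ops inside the launch body.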
class MgbToGpuLoweringPass
        : public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
public:
    void runOnFunction() override final {
        auto func_op = getFunction();
        Location loc = func_op.getLoc();
        OpBuilder builder(&func_op.getBody());
        Value constantOne = builder.create<ConstantIndexOp>(loc, 1);
        gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>(
                loc, constantOne, constantOne, constantOne, constantOne,
                constantOne, constantOne);
        builder.setInsertionPointToEnd(&(launch_op.body().front()));
        builder.create<gpu::TerminatorOp>(loc);

        OwningRewritePatternList patterns;
        ConversionTarget target(getContext());
        target.addLegalDialect<StandardOpsDialect>();
        target.addLegalDialect<gpu::GPUDialect>();
        target.addIllegalDialect<MgbDialect>();

        patterns.insert<AddOpLowering, AssignOpLowering, ReturnOpLowering>(
                &getContext(), &launch_op);

        if (failed(applyPartialConversion(func_op, target, patterns))) {
            signalPassFailure();
        }
    }
};
}  // namespace
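//! Factory entry point for this pass, declared in the included
//! megbrain/jit/mlir/ir/passes.h.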
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
    return std::make_unique<MgbToGpuLoweringPass>();
}
#endif  // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
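A minimal usage sketch, under stated assumptions: this file does not show how the pass is scheduled, so the driver below is hypothetical (the function name `lower_to_gpu` and the way `module` is produced are illustrative, not part of MegEngine's API). Since `MgbToGpuLoweringPass` is a `FunctionPass`, it is nested under `FuncOp` in an `mlir::PassManager` of the same MLIR vintage:

    #include <mlir/IR/Function.h>
    #include <mlir/IR/Module.h>
    #include <mlir/Pass/PassManager.h>

    #include "megbrain/jit/mlir/ir/passes.h"

    //! Hypothetical driver: `module` is assumed to contain functions built
    //! from mgb dialect ops (plus the trailing nr_tid argument the guard reads).
    void lower_to_gpu(mlir::MLIRContext& context, mlir::ModuleOp module) {
        mlir::PassManager pm(&context);
        //! the pass runs on functions, so nest it under FuncOp
        pm.nest<mlir::FuncOp>().addPass(mgb::jit::create_lower_to_gpu_pass());
        if (mlir::failed(pm.run(module))) {
            //! lowering failed; a real driver would report the error
        }
    }

After this pass the gpu.launch region still lives inside the host function; a pass such as mlir::createGpuKernelOutliningPass() would typically run next to outline it into a standalone GPU kernel.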
