
lower_to_gpu_pass.cpp

/**
 * \file src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./common.h"
#include "./each_mode.h"

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/utils.h"

#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>

#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

using namespace mgb;
using namespace jit;

namespace {
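//! compute the global 1-D thread id: blockIdx.x * blockDim.x + threadIdx.x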
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
    auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
    auto block_idx = rewriter.create<gpu::BlockIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
    auto group_size = rewriter.create<gpu::BlockDimOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));

    Value index = rewriter.create<AddIOp>(
            loc, thread_idx,
            rewriter.create<MulIOp>(loc, block_idx, group_size));
    return index;
}
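//! infer the layout of the kernel output: scan the enclosing function
//! backwards for the last dialect::AssignOp and take the layout of its first
//! operand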
megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
    auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
    mgb_assert(func_op, "Unexpected launch op.");
    for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
         block_iter++) {
        for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
             op_iter++) {
            auto op = llvm::dyn_cast_or_null<dialect::AssignOp>(&(*op_iter));
            if (op && op.getNumOperands() > 0) {
                return mlir_type_to_layout(*(op.operand_type_begin()));
            }
        }
    }
    mgb_throw(MegBrainError, "Unexpected launch op.");
}
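//! delinearize the flat thread id into a multi-dimensional index against the
//! destination layout \p dst; input dimensions of extent 1 are mapped to
//! index 0 so that broadcast operands read the same element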
std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
                                          const Location& loc,
                                          const mlir::Value& val,
                                          const megdnn::TensorLayout& dst) {
    Value index = get_tid(rewriter, loc);
    auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
    if (type) {
        ValueBuilderHelper helper(rewriter, loc);
        std::vector<mlir::Value> idxs;
        idxs.resize(dst.ndim);
        mlir::Value dim_index = index;
        for (int i = dst.ndim - 1; i >= 0; i--) {
            auto cur_index = helper.modI(dim_index, helper.const_i32(dst[i]));
            idxs[i] = cur_index;
            dim_index = helper.divI(dim_index, helper.const_i32(dst[i]));
        }

        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
        src_layout.init_contiguous_stride();
        for (int i = 0; i < type.getRank(); ++i) {
            if (src_layout[i] == 1) {
                idxs[i] = helper.const_i32(0);
            }
        }
        return idxs;
    } else {
        return {index};
    }
}
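//! lower dialect::Elemwise: load every input at the thread's
//! multi-dimensional index, then emit the equivalent scalar std ops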
struct ElemwiseLowering : public ConversionPattern {
    ElemwiseLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto dst_layout = output_layout(m_launch_op);
        auto inputs = llvm::to_vector<4>(
                llvm::map_range(operands, [&](mlir::Value val) {
                    auto index =
                            get_multidim_tid(rewriter, loc, val, dst_layout);
                    return get_operand<LoadOp>(rewriter, loc, val, index);
                }));
        rewriter.replaceOp(op,
                           lower_elemwise_to_std(op, rewriter, loc, inputs));
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
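//! lower dialect::TypeCvt: load the input element and convert it to the
//! result element type via standard cast ops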
struct TypeCvtLowering : public ConversionPattern {
    TypeCvtLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto dst_layout = output_layout(m_launch_op);
        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
        auto input = get_operand<LoadOp>(rewriter, loc, operands[0], index);
        rewriter.replaceOp(op, lower_typecvt_to_std(op, rewriter, loc, input));
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
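//! lower dialect::Dimshuffle: permute the thread's multi-dimensional index
//! by the shuffle pattern and load the input at the permuted index, e.g.
//! pattern (1, 0) transposes a 2-D input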
struct DimshuffleLowering : public ConversionPattern {
    DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
                                ctx),
              m_launch_op{launch_op} {}

    static std::vector<mlir::Value> get_index_from_pattern(
            const std::vector<int32_t>& pattern,
            const std::vector<mlir::Value>& index) {
        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
        std::vector<mlir::Value> res(ndim);
        for (size_t i = 0; i < pattern.size(); i++) {
            int32_t j = pattern[i];
            if (j >= 0) {
                res[j] = index[i];
            }
        }
        return res;
    }

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto dst_layout = output_layout(m_launch_op);
        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
        auto shuffled_index = get_index_from_pattern(pattern, index);
        rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0],
                                                   shuffled_index));
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
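//! lower dialect::ReturnOp and guard the kernel body with a bounds check:
//! the entry block is split so that threads with tid >= nr_tid branch
//! straight to a gpu.terminator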
struct ReturnOpLowering : public ConversionPattern {
    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value>,
            ConversionPatternRewriter& rewriter) const final {
        rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
        auto loc = op->getLoc();

        //! remove the first gpu.terminator
        m_launch_op.body().front().front().erase();

        //! insert `if (tid >= nr_tid) { return; }` at the beginning of the
        //! block
        rewriter.setInsertionPointToStart(&(m_launch_op.body().front()));
        Block* cond_block = rewriter.getInsertionBlock();
        Block::iterator op_position = rewriter.getInsertionPoint();
        Block* remaining_ops_block =
                rewriter.splitBlock(cond_block, op_position);
        rewriter.setInsertionPointToEnd(cond_block);

        auto index = get_tid(rewriter, loc);
        auto comparison = rewriter.create<mlir::CmpIOp>(
                loc, CmpIPredicate::sge, index,
                m_launch_op.getParentOfType<mlir::FuncOp>()
                        .getArguments()
                        .back());

        Block* then_block =
                rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
        rewriter.setInsertionPointToEnd(then_block);
        rewriter.create<gpu::TerminatorOp>(loc);

        rewriter.setInsertionPointToEnd(cond_block);
        rewriter.create<mlir::CondBranchOp>(
                loc, comparison, then_block, ArrayRef<Value>(),
                remaining_ops_block, ArrayRef<Value>());

        rewriter.setInsertionPointToEnd(remaining_ops_block);
        rewriter.create<gpu::TerminatorOp>(loc);
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
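//! lower dialect::ConstantScalarOp to a std ConstantOp inside the launch body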
struct ConstantScalarOpLowering
        : public OpRewritePattern<dialect::ConstantScalarOp> {
    ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : OpRewritePattern<dialect::ConstantScalarOp>(ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
                                  PatternRewriter& rewriter) const final {
        dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
                op, constant_scalar_adaptor.value());
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
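//! lower dialect::AssignOp: load from the lhs at the thread's index and
//! store the value into the rhs (the output memref) at the same index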
struct AssignOpLowering : public ConversionPattern {
    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        dialect::AssignOpAdaptor assign_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto dst_layout = output_layout(m_launch_op);
        auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
                                      dst_layout);
        auto loaded_lhs =
                get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
        rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);

        rewriter.eraseOp(op);
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
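//! the pass entry point: wrap the function body in a single gpu.launch op
//! (all launch dimensions are created as the constant 1 here), then run
//! dialect conversion to rewrite every MgbDialect op into std/gpu ops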
class MgbToGpuLoweringPass
        : public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
public:
    void getDependentDialects(mlir::DialectRegistry& registry) const override {
        registry.insert<mlir::gpu::GPUDialect>();
        registry.insert<mlir::StandardOpsDialect>();
    }

    void runOnFunction() override final {
        auto func_op = getFunction();
        Location loc = func_op.getLoc();
        OpBuilder builder(&func_op.getBody());
        Value constantOne = builder.create<ConstantIndexOp>(loc, 1);
        gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>(
                loc, constantOne, constantOne, constantOne, constantOne,
                constantOne, constantOne);
        builder.setInsertionPointToEnd(&(launch_op.body().front()));
        builder.create<gpu::TerminatorOp>(loc);

        OwningRewritePatternList patterns;
        ConversionTarget target(getContext());
        target.addLegalDialect<StandardOpsDialect>();
        target.addLegalDialect<gpu::GPUDialect>();
        target.addIllegalDialect<MgbDialect>();

        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
                        ReturnOpLowering, ConstantScalarOpLowering,
                        AssignOpLowering>(&getContext(), launch_op);
        if (failed(applyPartialConversion(func_op, target,
                                          std::move(patterns)))) {
            signalPassFailure();
        }
    }
};

}  // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
    return std::make_unique<MgbToGpuLoweringPass>();
}

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen
