You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

common.cpp 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /**
  2. * \file src/jit/impl/mlir/ir/common.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain_build_config.h"
  13. #if MGB_JIT && MGB_JIT_MLIR
  14. #include "./common.h"
  15. #include "megbrain/jit/mlir/ir/utils.h"
  16. #include <mlir/Dialect/Affine/IR/AffineOps.h>
  17. #include <mlir/Dialect/StandardOps/IR/Ops.h>
  18. using namespace mgb;
  19. using namespace jit;
  20. /* ===================== trivial unary functions ===================== */
  21. #define cb(name, op) \
  22. mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \
  23. return m_builder.create<mlir::op>(m_location, lhs); \
  24. }
  25. cb(abs, AbsFOp);
  26. cb(ceil, CeilFOp);
  27. cb(cos, CosOp);
  28. cb(exp2, Exp2Op);
  29. cb(exp, ExpOp);
  30. cb(floor, FloorFOp);
  31. cb(log10, Log10Op);
  32. cb(log2, Log2Op);
  33. cb(log, LogOp);
  34. cb(neg, NegFOp);
  35. cb(rsqrt, RsqrtOp);
  36. cb(sin, SinOp);
  37. cb(sqrt, SqrtOp);
  38. cb(tanh, TanhOp);
  39. #undef cb
  40. /* ===================== trivial binary functions ===================== */
  41. #define cb(name, op) \
  42. mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
  43. return m_builder.create<mlir::op>(m_location, lhs, rhs); \
  44. }
  45. cb(add, AddFOp);
  46. cb(bit_and, AndOp);
  47. cb(bit_or, OrOp);
  48. cb(div, DivFOp);
  49. cb(divI, SignedDivIOp);
  50. cb(modI, SignedRemIOp);
  51. cb(mod, RemFOp);
  52. cb(mul, MulFOp);
  53. cb(sub, SubFOp);
  54. #undef cb
  55. /* ===================== compare functions ===================== */
  56. #define cb(name, mode) \
  57. mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
  58. return m_builder.create<mlir::CmpFOp>( \
  59. m_location, mlir::CmpFPredicate::mode, lhs, rhs); \
  60. }
  61. cb(eq, OEQ);
  62. cb(ge, OGE);
  63. cb(gt, OGT);
  64. cb(le, OLE);
  65. cb(lt, OLT);
  66. #undef cb
  67. mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
  68. mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
  69. m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
  70. return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
  71. }
  72. mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
  73. mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
  74. m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
  75. return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
  76. }
  77. /* ===================== constant functions ===================== */
  78. mlir::Value ValueBuilderHelper::const_f32(float val) {
  79. return m_builder.create<mlir::ConstantOp>(m_location,
  80. m_builder.getF32FloatAttr(val));
  81. }
//! Emit an integer ConstantOp holding \p val.
//! NOTE(review): getIndexAttr() produces an index-typed constant rather than a
//! true i32 — presumably intentional because these constants feed affine
//! indexing, but the name is misleading; confirm against callers before
//! switching to getI32IntegerAttr().
mlir::Value ValueBuilderHelper::const_i32(int32_t val) {
    return m_builder.create<mlir::ConstantOp>(m_location,
                                              m_builder.getIndexAttr(val));
}
  86. /* ===================== select function ===================== */
  87. mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
  88. mlir::Value false_val) {
  89. return m_builder.create<mlir::SelectOp>(m_location, cond, true_val,
  90. false_val);
  91. }
  92. /* ===================== helper functions ===================== */
  93. mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
  94. const mlir::Value& val,
  95. const megdnn::TensorLayout& layout) {
  96. auto type = val.getType().cast<mlir::MemRefType>();
  97. mgb_assert(type, "currently only support MemRefType");
  98. std::vector<mlir::AffineExpr> exprs;
  99. for (int i = 0; i < type.getRank(); ++i) {
  100. if (layout[i] == 1) {
  101. exprs.push_back(builder.getAffineConstantExpr(0));
  102. } else {
  103. exprs.push_back(builder.getAffineDimExpr(i));
  104. }
  105. }
  106. auto map = mlir::AffineMap::get(type.getRank(), 0, exprs,
  107. builder.getContext());
  108. return map;
  109. }
  110. mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
  111. const mlir::Location& loc,
  112. const mlir::Value& val,
  113. const mlir::ValueRange& index,
  114. const megdnn::TensorLayout& dst) {
  115. if (val.getType().isa<mlir::MemRefType>()) {
  116. auto type = val.getType().cast<mlir::MemRefType>();
  117. megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
  118. src_layout.init_contiguous_stride();
  119. if (src_layout.eq_shape(dst)) {
  120. return builder.create<mlir::AffineLoadOp>(loc, val, index);
  121. } else {
  122. auto lhs_map = get_affinemap(builder, val, src_layout);
  123. return builder.create<mlir::AffineLoadOp>(loc, val, lhs_map, index);
  124. }
  125. } else {
  126. return val;
  127. }
  128. }
  129. #endif // MGB_JIT && MGB_JIT_MLIR
  130. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台