Browse Source

feat(mgb/jit): add Dimshuffle and lowering passes in jit mlir backend

GitOrigin-RevId: ce6f4ea42a
release-1.1
Megvii Engine Team 4 years ago
parent
commit
5d0f8da46a
5 changed files with 166 additions and 11 deletions
  1. +1
    -1
      src/jit/impl/mlir/compiler.h
  2. +40
    -3
      src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
  3. +44
    -3
      src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
  4. +35
    -4
      src/jit/impl/mlir/mlir_gen.cpp
  5. +46
    -0
      src/jit/test/codegen.cpp

+ 1
- 1
src/jit/impl/mlir/compiler.h View File

@@ -34,7 +34,7 @@ public:
Property property() const override { Property property() const override {
using F = Property::Flag; using F = Property::Flag;
return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM, return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
JITFeatureBits::NONE, 64};
JITFeatureBits::DIMSHUFFLE, 64};
} }


size_t get_nr_workspace_outputs(JITExecutor* opr) const override; size_t get_nr_workspace_outputs(JITExecutor* opr) const override;


+ 40
- 3
src/jit/impl/mlir/ir/lower_to_affine_pass.cpp View File

@@ -62,6 +62,7 @@ struct ElemwiseLowering : public ConversionPattern {
ElemwiseLowering(MLIRContext* ctx) ElemwiseLowering(MLIRContext* ctx)
: ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1, : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
ctx) {} ctx) {}

LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands, Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
@@ -89,6 +90,7 @@ struct TypeCvtLowering : public ConversionPattern {
TypeCvtLowering(MLIRContext* ctx) TypeCvtLowering(MLIRContext* ctx)
: ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1, : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
ctx) {} ctx) {}

LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands, Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
@@ -105,6 +107,41 @@ struct TypeCvtLowering : public ConversionPattern {
} }
}; };


//! Lower dialect::Dimshuffle to affine loops: the generated loop nest
//! iterates over the output space and loads the input through an affine map
//! that permutes the loop induction variables according to the pattern.
struct DimshuffleLowering : public ConversionPattern {
    DimshuffleLowering(MLIRContext* ctx)
            : ConversionPattern(mgb::dialect::Dimshuffle::getOperationName(), 1,
                                ctx) {}

    //! Build the affine map from output coordinates to input coordinates.
    //! For every non-negative entry pattern[i] == j, input dim j is read at
    //! output dim i; entries equal to -1 denote inserted broadcast axes of
    //! extent 1 and contribute no input coordinate.
    static mlir::AffineMap get_affinemap_from_pattern(
            const std::vector<int32_t>& pattern, mlir::MLIRContext* ctx) {
        // input rank = largest referenced source axis + 1
        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
        std::vector<mlir::AffineExpr> exprs(ndim);
        for (size_t i = 0; i < pattern.size(); i++) {
            int32_t j = pattern[i];
            if (j >= 0) {
                exprs[j] = mlir::getAffineDimExpr(i, ctx);
            }
        }
        return mlir::AffineMap::get(pattern.size(), 0, exprs, ctx);
    }

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        // the pattern matcher already guarantees op is a Dimshuffle, so use
        // the asserting cast<> instead of an unchecked dyn_cast<> (calling a
        // member on a failed dyn_cast would be UB)
        auto pattern = llvm::cast<dialect::Dimshuffle>(op).pattern();
        auto map = get_affinemap_from_pattern(pattern, op->getContext());
        lower_op_to_loops(
                op, operands, rewriter,
                [loc, &map](OpBuilder& builder, ValueRange memref_operands,
                            ValueRange loop_ivs) {
                    // load the input element at the permuted coordinates
                    return builder.create<AffineLoadOp>(loc, memref_operands[0],
                                                        map, loop_ivs);
                });
        return success();
    }
};

struct AssignOpLowering : public ConversionPattern { struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx) AssignOpLowering(MLIRContext* ctx)
: ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) { : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
@@ -172,9 +209,9 @@ public:
target.addIllegalDialect<MgbDialect>(); target.addIllegalDialect<MgbDialect>();


OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
AssignOpLowering, ConstantScalarOpLowering>(
&getContext());
patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
ReturnOpLowering, AssignOpLowering,
ConstantScalarOpLowering>(&getContext());


if (failed(applyPartialConversion(getFunction(), target, if (failed(applyPartialConversion(getFunction(), target,
std::move(patterns)))) { std::move(patterns)))) {


+ 44
- 3
src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp View File

@@ -152,6 +152,47 @@ private:
gpu::LaunchOp& m_launch_op; gpu::LaunchOp& m_launch_op;
}; };


//! Lower dialect::Dimshuffle inside a gpu.launch region: compute the
//! per-thread multi-dimensional index of the output element and load the
//! input at the index permuted by the dimshuffle pattern.
struct DimshuffleLowering : public ConversionPattern {
    DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
                                ctx),
              m_launch_op{launch_op} {}

    //! Map an output-space index back to input-space coordinates:
    //! res[pattern[i]] = index[i] for every non-negative pattern entry;
    //! entries equal to -1 denote inserted broadcast axes of extent 1 and
    //! contribute no input coordinate.
    static std::vector<mlir::Value> get_index_from_pattern(
            const std::vector<int32_t>& pattern,
            const std::vector<mlir::Value>& index) {
        // input rank = largest referenced source axis + 1
        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
        std::vector<mlir::Value> res(ndim);
        for (size_t i = 0; i < pattern.size(); i++) {
            int32_t j = pattern[i];
            if (j >= 0) {
                res[j] = index[i];
            }
        }
        return res;
    }

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();

        // emit into the launch body so the load executes per thread
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

        auto dst_layout = output_layout(m_launch_op);
        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
        // the matcher already guarantees op is a Dimshuffle; cast<> asserts
        // instead of risking UB through an unchecked dyn_cast<>
        auto pattern = llvm::cast<dialect::Dimshuffle>(op).pattern();
        auto shuffled_index = get_index_from_pattern(pattern, index);

        rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0],
                                                   shuffled_index));
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};

struct ReturnOpLowering : public ConversionPattern { struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx), : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
@@ -275,9 +316,9 @@ public:
target.addLegalDialect<gpu::GPUDialect>(); target.addLegalDialect<gpu::GPUDialect>();
target.addIllegalDialect<MgbDialect>(); target.addIllegalDialect<MgbDialect>();


patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
ConstantScalarOpLowering, AssignOpLowering>(
&getContext(), launch_op);
patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
ReturnOpLowering, ConstantScalarOpLowering,
AssignOpLowering>(&getContext(), launch_op);


if (failed(applyPartialConversion(func_op, target, if (failed(applyPartialConversion(func_op, target,
std::move(patterns)))) { std::move(patterns)))) {


+ 35
- 4
src/jit/impl/mlir/mlir_gen.cpp View File

@@ -20,6 +20,7 @@
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/utils.h" #include "megbrain/jit/mlir/ir/utils.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/dtype.h" #include "megdnn/dtype.h"


#include <mlir/Dialect/Affine/IR/AffineOps.h> #include <mlir/Dialect/Affine/IR/AffineOps.h>
@@ -160,6 +161,10 @@ private:
mgb_assert( mgb_assert(
mlir::succeeded(declare(opr->output(0)->name(), out))); mlir::succeeded(declare(opr->output(0)->name(), out)));
return; return;
} else if (opr->same_type<opr::Dimshuffle>()) {
auto&& out = gen_dimshuffle(opr->cast_final<opr::Dimshuffle>());
mgb_assert(
mlir::succeeded(declare(opr->output(0)->name(), out)));
} else if (opr->same_type<opr::TypeCvt>()) { } else if (opr->same_type<opr::TypeCvt>()) {
auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>()); auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
mgb_assert( mgb_assert(
@@ -186,18 +191,44 @@ private:
} }


mlir::Value gen_typecvt(const opr::TypeCvt& opr) { mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
auto shape = get(opr.input(0))
auto itype = get(opr.input(0))
.getType() .getType()
.dyn_cast_or_null<mlir::MemRefType>()
.getShape();
.dyn_cast_or_null<mlir::MemRefType>();
mgb_assert(itype, "currently only support MemRefType");
auto res_type = mlir::MemRefType::get( auto res_type = mlir::MemRefType::get(
shape,
itype.getShape(),
megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext())); megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
return m_builder.create<dialect::TypeCvt>( return m_builder.create<dialect::TypeCvt>(
m_builder.getUnknownLoc(), res_type, get(opr.input(0)), m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
opr.input(0)->dtype(), opr.param()); opr.input(0)->dtype(), opr.param());
} }


//! Emit a dialect::Dimshuffle op for the given megbrain Dimshuffle opr.
//! The output shape is derived from the input memref shape and the
//! dimshuffle pattern: a negative pattern entry inserts a broadcast axis
//! of extent 1, a non-negative entry copies the corresponding input extent.
mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) {
    auto itype = get(opr.input(0))
                         .getType()
                         .dyn_cast_or_null<mlir::MemRefType>();
    mgb_assert(itype, "the input type of Dimshuffle must be MemRefType");
    auto ishape = itype.getShape();
    auto param = opr.param();

    std::vector<int32_t> pattern;
    std::vector<int64_t> oshape;
    pattern.reserve(param.pattern_len);
    oshape.reserve(param.pattern_len);
    for (size_t i = 0; i < param.pattern_len; ++i) {
        int32_t axis = param.pattern[i];
        pattern.push_back(axis);
        oshape.push_back(axis < 0 ? 1 : ishape[axis]);
    }
    auto res_type = mlir::MemRefType::get(oshape, itype.getElementType());

    return m_builder.create<dialect::Dimshuffle>(
            m_builder.getUnknownLoc(), res_type, get(opr.input(0)), pattern);
}

mlir::Type get_type(const TensorLayout& layout) { mlir::Type get_type(const TensorLayout& layout) {
return layout_to_mlir_type(layout, m_builder); return layout_to_mlir_type(layout, m_builder);
} }


+ 46
- 0
src/jit/test/codegen.cpp View File

@@ -15,6 +15,7 @@
#include "megbrain/jit/executor_opr.h" #include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#include "megdnn/dtype.h" #include "megdnn/dtype.h"


@@ -539,6 +540,51 @@ add_typecvt_gtest(Uint8, Float32);


#undef add_typecvt_gtest #undef add_typecvt_gtest


/* ===================== TestJITMlirDimshuffle ===================== */

//! Build a graph that computes Dimshuffle(x, pattern) twice — once through
//! the normal interpreter and once through the MLIR JIT executor — and
//! assert that both paths produce identical tensors.
void run_dimshuffle(CompNode cn, TensorShape ishape,
                    const std::vector<int>& pattern) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;

    auto host_x = gen(ishape, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::Dimshuffle::make(x, pattern);

    // collect every operator feeding y (except the host-to-device copy,
    // which stays outside) into the internal graph for JIT compilation
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto&& node : get_rev_topo_order(y)) {
        if (!node->template same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(node);
        }
    }

    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());

    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();

    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}

//! run both a pure-permutation pattern and a pattern that inserts a
//! broadcast (-1) axis, on the given comp node
void run_dimshuffle_cases(CompNode cn) {
// transpose-only pattern
run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1});
// pattern that inserts a broadcast axis of extent 1
run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2});
}

// exercises the affine (CPU) lowering path
TEST(TestJITMlirDimshuffle, Basic) {
run_dimshuffle_cases(CompNode::load("cpu0"));
}

// exercises the gpu lowering path; skipped when no GPU is available
TEST(TestJITMlirDimshuffle, BasicGPU) {
REQUIRE_GPU(1);
run_dimshuffle_cases(CompNode::load("gpu0"));
}

#endif // MGB_JIT_MLIR #endif // MGB_JIT_MLIR


#endif // MGB_JIT #endif // MGB_JIT


Loading…
Cancel
Save