GitOrigin-RevId: ce6f4ea42a
release-1.1
@@ -34,7 +34,7 @@ public:
     Property property() const override {
         using F = Property::Flag;
         return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
-                        JITFeatureBits::NONE, 64};
+                        JITFeatureBits::DIMSHUFFLE, 64};
     }
     size_t get_nr_workspace_outputs(JITExecutor* opr) const override;
@@ -62,6 +62,7 @@ struct ElemwiseLowering : public ConversionPattern {
     ElemwiseLowering(MLIRContext* ctx)
             : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
                                 ctx) {}
+
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -89,6 +90,7 @@ struct TypeCvtLowering : public ConversionPattern {
     TypeCvtLowering(MLIRContext* ctx)
             : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
                                 ctx) {}
+
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -105,6 +107,41 @@ struct TypeCvtLowering : public ConversionPattern {
     }
 };
 
+struct DimshuffleLowering : public ConversionPattern {
+    DimshuffleLowering(MLIRContext* ctx)
+            : ConversionPattern(mgb::dialect::Dimshuffle::getOperationName(), 1,
+                                ctx) {}
+
+    static mlir::AffineMap get_affinemap_from_pattern(
+            const std::vector<int32_t>& pattern, mlir::MLIRContext* ctx) {
+        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
+        std::vector<mlir::AffineExpr> exprs(ndim);
+        for (size_t i = 0; i < pattern.size(); i++) {
+            int32_t j = pattern[i];
+            if (j >= 0) {
+                exprs[j] = mlir::getAffineDimExpr(i, ctx);
+            }
+        }
+        return mlir::AffineMap::get(pattern.size(), 0, exprs, ctx);
+    }
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
+        auto map = get_affinemap_from_pattern(pattern, op->getContext());
+        lower_op_to_loops(
+                op, operands, rewriter,
+                [loc, op, &map](OpBuilder& builder, ValueRange memref_operands,
+                                ValueRange loop_ivs) {
+                    return builder.create<AffineLoadOp>(loc, memref_operands[0],
+                                                        map, loop_ivs);
+                });
+        return success();
+    }
+};
+
 struct AssignOpLowering : public ConversionPattern {
     AssignOpLowering(MLIRContext* ctx)
             : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
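For illustration (not part of the patch itself): the affine map built by get_affinemap_from_pattern sends an output coordinate to the input coordinate it reads, and a pattern entry of -1 marks a broadcast axis of extent 1 that selects no input dimension. Below is a minimal standalone sketch of that index mapping in plain C++ without MLIR; the helper name is made up for the example.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Illustration only: mirrors the mapping that get_affinemap_from_pattern
    // encodes. For an output coordinate out_idx, compute the input coordinate
    // the lowered AffineLoadOp reads; -1 entries select no input dimension.
    static std::vector<int64_t> input_index_from_output_index(
            const std::vector<int32_t>& pattern,
            const std::vector<int64_t>& out_idx) {
        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
        std::vector<int64_t> in_idx(ndim, 0);
        for (size_t i = 0; i < pattern.size(); ++i) {
            if (pattern[i] >= 0) {
                in_idx[pattern[i]] = out_idx[i];
            }
        }
        return in_idx;
    }

    int main() {
        // pattern {2, 0, 1}: output[a][b][c] reads input[b][c][a], i.e. the
        // affine map (d0, d1, d2) -> (d1, d2, d0).
        auto idx = input_index_from_output_index({2, 0, 1}, {7, 8, 9});
        std::printf("%lld %lld %lld\n", (long long)idx[0], (long long)idx[1],
                    (long long)idx[2]);  // prints: 8 9 7
    }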
@@ -172,9 +209,9 @@ public:
         target.addIllegalDialect<MgbDialect>();
         OwningRewritePatternList patterns;
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        AssignOpLowering, ConstantScalarOpLowering>(
-                &getContext());
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, AssignOpLowering,
+                        ConstantScalarOpLowering>(&getContext());
         if (failed(applyPartialConversion(getFunction(), target,
                                           std::move(patterns)))) {
@@ -152,6 +152,47 @@ private:
     gpu::LaunchOp& m_launch_op;
 };
 
+struct DimshuffleLowering : public ConversionPattern {
+    DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
+                                ctx),
+              m_launch_op{launch_op} {}
+
+    static std::vector<mlir::Value> get_index_from_pattern(
+            const std::vector<int32_t>& pattern,
+            const std::vector<mlir::Value>& index) {
+        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
+        std::vector<mlir::Value> res(ndim);
+        for (size_t i = 0; i < pattern.size(); i++) {
+            int32_t j = pattern[i];
+            if (j >= 0) {
+                res[j] = index[i];
+            }
+        }
+        return res;
+    }
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
+        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
+        auto shuffled_index = get_index_from_pattern(pattern, index);
+        rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0],
+                                                   shuffled_index));
+        return success();
+    }
+
+private:
+    gpu::LaunchOp& m_launch_op;
+};
+
 struct ReturnOpLowering : public ConversionPattern {
     ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
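For illustration (not part of the patch itself): the GPU path applies the same permutation directly to the per-thread multi-dimensional index; get_index_from_pattern moves the output-coordinate values into input order and skips -1 entries, since those output axes have extent 1. A small plain-C++ sketch of the broadcast case:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // Illustration only: pattern {1, -1, 0, 2} on a 3-d input gives a
        // 4-d output whose axis 1 has extent 1; that coordinate is dropped
        // when the input index is formed (res[pattern[i]] = index[i] only
        // for pattern[i] >= 0, as in get_index_from_pattern).
        std::vector<int32_t> pattern{1, -1, 0, 2};
        std::vector<int64_t> out_idx{7, 0, 8, 9};  // one thread's output coord
        std::vector<int64_t> in_idx(3, 0);         // max(pattern) + 1 == 3
        for (size_t i = 0; i < pattern.size(); ++i) {
            if (pattern[i] >= 0) {
                in_idx[pattern[i]] = out_idx[i];
            }
        }
        std::printf("%lld %lld %lld\n", (long long)in_idx[0],
                    (long long)in_idx[1], (long long)in_idx[2]);  // 8 7 9
    }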
@@ -275,9 +316,9 @@ public:
         target.addLegalDialect<gpu::GPUDialect>();
         target.addIllegalDialect<MgbDialect>();
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        ConstantScalarOpLowering, AssignOpLowering>(
-                &getContext(), launch_op);
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, ConstantScalarOpLowering,
+                        AssignOpLowering>(&getContext(), launch_op);
         if (failed(applyPartialConversion(func_op, target,
                                           std::move(patterns)))) {
@@ -20,6 +20,7 @@
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/utils.h"
 #include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megdnn/dtype.h"
 
 #include <mlir/Dialect/Affine/IR/AffineOps.h>
@@ -160,6 +161,10 @@ private:
             mgb_assert(
                     mlir::succeeded(declare(opr->output(0)->name(), out)));
             return;
+        } else if (opr->same_type<opr::Dimshuffle>()) {
+            auto&& out = gen_dimshuffle(opr->cast_final<opr::Dimshuffle>());
+            mgb_assert(
+                    mlir::succeeded(declare(opr->output(0)->name(), out)));
         } else if (opr->same_type<opr::TypeCvt>()) {
             auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
             mgb_assert(
@@ -186,18 +191,44 @@ private:
     }
 
     mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
-        auto shape = get(opr.input(0))
+        auto itype = get(opr.input(0))
                              .getType()
-                             .dyn_cast_or_null<mlir::MemRefType>()
-                             .getShape();
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "currently only support MemRefType");
         auto res_type = mlir::MemRefType::get(
-                shape,
+                itype.getShape(),
                 megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
         return m_builder.create<dialect::TypeCvt>(
                 m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
                 opr.input(0)->dtype(), opr.param());
     }
+
+    mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) {
+        auto itype = get(opr.input(0))
+                             .getType()
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "the input type of Dimshuffle must be MemRefType");
+        auto ishape = itype.getShape();
+        auto param = opr.param();
+        std::vector<int32_t> pattern;
+        std::vector<int64_t> oshape;
+        for (size_t i = 0; i < param.pattern_len; i++) {
+            int32_t j = param.pattern[i];
+            pattern.push_back(j);
+            if (j < 0) {
+                oshape.push_back(1);
+            } else {
+                oshape.push_back(ishape[j]);
+            }
+        }
+        auto res_type = mlir::MemRefType::get(oshape, itype.getElementType());
+        return m_builder.create<dialect::Dimshuffle>(
+                m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
+                pattern);
+    }
 
     mlir::Type get_type(const TensorLayout& layout) {
         return layout_to_mlir_type(layout, m_builder);
     }
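For illustration (not part of the patch itself): gen_dimshuffle derives the result shape with the rule oshape[i] = 1 when pattern[i] < 0, otherwise ishape[pattern[i]]. A short sketch of just that shape arithmetic, using the shapes exercised by the new tests below:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Illustration only: the shape rule used by gen_dimshuffle.
    static std::vector<int64_t> dimshuffle_shape(
            const std::vector<int64_t>& ishape,
            const std::vector<int32_t>& pattern) {
        std::vector<int64_t> oshape;
        for (int32_t j : pattern) {
            oshape.push_back(j < 0 ? 1 : ishape[j]);
        }
        return oshape;
    }

    int main() {
        for (auto v : dimshuffle_shape({3, 4, 5}, {2, 0, 1}))
            std::printf("%lld ", (long long)v);  // 5 3 4
        std::printf("\n");
        for (auto v : dimshuffle_shape({3, 4, 5}, {1, -1, 0, 2}))
            std::printf("%lld ", (long long)v);  // 4 1 3 5
        std::printf("\n");
    }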
@@ -15,6 +15,7 @@
 #include "megbrain/jit/executor_opr.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/test/helper.h"
 #include "megdnn/dtype.h"
@@ -539,6 +540,51 @@ add_typecvt_gtest(Uint8, Float32);
 #undef add_typecvt_gtest
 
+/* ===================== TestJITMlirDimshuffle ===================== */
+
+void run_dimshuffle(CompNode cn, TensorShape ishape,
+                    const std::vector<int>& pattern) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+    auto host_x = gen(ishape, cn);
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    auto y = opr::Dimshuffle::make(x, pattern);
+
+    auto ig_gen =
+            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->template same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+
+void run_dimshuffle_cases(CompNode cn) {
+    run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1});
+    run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2});
+}
+
+TEST(TestJITMlirDimshuffle, Basic) {
+    run_dimshuffle_cases(CompNode::load("cpu0"));
+}
+
+TEST(TestJITMlirDimshuffle, BasicGPU) {
+    REQUIRE_GPU(1);
+    run_dimshuffle_cases(CompNode::load("gpu0"));
+}
+
 #endif // MGB_JIT_MLIR
 #endif // MGB_JIT