GitOrigin-RevId: ce6f4ea42a
release-1.1
@@ -34,7 +34,7 @@ public:
     Property property() const override {
         using F = Property::Flag;
         return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
-                        JITFeatureBits::NONE, 64};
+                        JITFeatureBits::DIMSHUFFLE, 64};
     }

     size_t get_nr_workspace_outputs(JITExecutor* opr) const override;
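Note on this hunk: switching the feature mask from NONE to DIMSHUFFLE is what lets the fusion pass hand Dimshuffle oprs to this backend at all; the lowering patterns added below are unreachable without it. A minimal sketch of how such a bit-flag mask typically gates acceptance (the enum layout and helper are illustrative assumptions, not the actual MegBrain definitions):

    // Sketch only: a bit-flag feature mask gating which oprs a fusor accepts.
    #include <cstdint>

    enum class JITFeatureBits : uint32_t {
        NONE = 0,
        DIMSHUFFLE = 1u << 0,  // backend knows how to lower Dimshuffle
    };

    constexpr bool has_feature(JITFeatureBits mask, JITFeatureBits bit) {
        return (static_cast<uint32_t>(mask) & static_cast<uint32_t>(bit)) != 0;
    }

    static_assert(
            has_feature(JITFeatureBits::DIMSHUFFLE, JITFeatureBits::DIMSHUFFLE),
            "a fusor advertising DIMSHUFFLE may absorb Dimshuffle oprs");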
@@ -62,6 +62,7 @@ struct ElemwiseLowering : public ConversionPattern {
     ElemwiseLowering(MLIRContext* ctx)
             : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
                                 ctx) {}
+
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -89,6 +90,7 @@ struct TypeCvtLowering : public ConversionPattern {
     TypeCvtLowering(MLIRContext* ctx)
             : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
                                 ctx) {}
+
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
@@ -105,6 +107,41 @@ struct TypeCvtLowering : public ConversionPattern {
     }
 };

+struct DimshuffleLowering : public ConversionPattern {
+    DimshuffleLowering(MLIRContext* ctx)
+            : ConversionPattern(mgb::dialect::Dimshuffle::getOperationName(),
+                                1, ctx) {}
+
+    static mlir::AffineMap get_affinemap_from_pattern(
+            const std::vector<int32_t>& pattern, mlir::MLIRContext* ctx) {
+        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
+        std::vector<mlir::AffineExpr> exprs(ndim);
+        for (size_t i = 0; i < pattern.size(); i++) {
+            int32_t j = pattern[i];
+            if (j >= 0) {
+                exprs[j] = mlir::getAffineDimExpr(i, ctx);
+            }
+        }
+        return mlir::AffineMap::get(pattern.size(), 0, exprs, ctx);
+    }
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
+        auto map = get_affinemap_from_pattern(pattern, op->getContext());
+        lower_op_to_loops(
+                op, operands, rewriter,
+                [loc, op, &map](OpBuilder& builder, ValueRange memref_operands,
+                                ValueRange loop_ivs) {
+                    return builder.create<AffineLoadOp>(loc, memref_operands[0],
+                                                        map, loop_ivs);
+                });
+        return success();
+    }
+};
+
 struct AssignOpLowering : public ConversionPattern {
     AssignOpLowering(MLIRContext* ctx)
             : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
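A worked example of the affine map construction above (my illustration, not part of the patch), using the broadcasting pattern {1, -1, 0, 2} that the tests below exercise, assuming an mlir::MLIRContext named ctx and <cassert>:

    // exprs[pattern[i]] = d_i, so for pattern {1, -1, 0, 2}:
    //   ndim = max(pattern) + 1 = 3 results, pattern.size() = 4 loop dims,
    //   giving the map (d0, d1, d2, d3) -> (d2, d0, d3).
    // Output iv i indexes input dim pattern[i]; the -1 (broadcast) axis d1
    // never appears in a result, so every d1 iteration reloads the same
    // input element, which is exactly the broadcast semantics Dimshuffle needs.
    auto map = DimshuffleLowering::get_affinemap_from_pattern(
            {1, -1, 0, 2}, &ctx);
    assert(map.getNumDims() == 4 && map.getNumResults() == 3);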
@@ -172,9 +209,9 @@ public:
         target.addIllegalDialect<MgbDialect>();

         OwningRewritePatternList patterns;
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        AssignOpLowering, ConstantScalarOpLowering>(
-                &getContext());
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, AssignOpLowering,
+                        ConstantScalarOpLowering>(&getContext());

         if (failed(applyPartialConversion(getFunction(), target,
                                           std::move(patterns)))) {
@@ -152,6 +152,47 @@ private:
     gpu::LaunchOp& m_launch_op;
 };

+struct DimshuffleLowering : public ConversionPattern {
+    DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
+                                ctx),
+              m_launch_op{launch_op} {}
+
+    static std::vector<mlir::Value> get_index_from_pattern(
+            const std::vector<int32_t>& pattern,
+            const std::vector<mlir::Value>& index) {
+        size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
+        std::vector<mlir::Value> res(ndim);
+        for (size_t i = 0; i < pattern.size(); i++) {
+            int32_t j = pattern[i];
+            if (j >= 0) {
+                res[j] = index[i];
+            }
+        }
+        return res;
+    }
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
+        auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
+        auto shuffled_index = get_index_from_pattern(pattern, index);
+
+        rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0],
+                                                   shuffled_index));
+        return success();
+    }
+
+private:
+    gpu::LaunchOp& m_launch_op;
+};
+
 struct ReturnOpLowering : public ConversionPattern {
     ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
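The GPU path performs the same permutation on concrete per-thread indices instead of affine dims: get_multidim_tid yields one index value per output axis, and get_index_from_pattern rearranges them into input coordinates. A self-contained sketch of that shuffling on plain ints (mlir::Value swapped for int so the mapping is easy to check; the function name is mine):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Same logic as get_index_from_pattern, on plain ints.
    std::vector<int> shuffle_index(const std::vector<int>& pattern,
                                   const std::vector<int>& index) {
        int ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
        std::vector<int> res(ndim);
        for (size_t i = 0; i < pattern.size(); ++i) {
            if (pattern[i] >= 0) {
                res[pattern[i]] = index[i];  // output iv i -> input dim pattern[i]
            }
        }
        return res;
    }

    int main() {
        // pattern {1, -1, 0, 2}: output element (i0, i1, i2, i3) reads
        // input element (i2, i0, i3); the broadcast iv i1 is unused.
        assert((shuffle_index({1, -1, 0, 2}, {10, 11, 12, 13}) ==
                std::vector<int>{12, 10, 13}));
    }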
@@ -275,9 +316,9 @@ public:
         target.addLegalDialect<gpu::GPUDialect>();
         target.addIllegalDialect<MgbDialect>();

-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        ConstantScalarOpLowering, AssignOpLowering>(
-                &getContext(), launch_op);
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, ConstantScalarOpLowering,
+                        AssignOpLowering>(&getContext(), launch_op);

         if (failed(applyPartialConversion(func_op, target,
                                           std::move(patterns)))) {
@@ -20,6 +20,7 @@
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/utils.h"
 #include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megdnn/dtype.h"

 #include <mlir/Dialect/Affine/IR/AffineOps.h>
@@ -160,6 +161,11 @@ private:
                 mgb_assert(
                         mlir::succeeded(declare(opr->output(0)->name(), out)));
                 return;
+            } else if (opr->same_type<opr::Dimshuffle>()) {
+                auto&& out = gen_dimshuffle(opr->cast_final<opr::Dimshuffle>());
+                mgb_assert(
+                        mlir::succeeded(declare(opr->output(0)->name(), out)));
+                return;
             } else if (opr->same_type<opr::TypeCvt>()) {
                 auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
                 mgb_assert(
@@ -186,18 +191,44 @@ private:
     }

     mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
-        auto shape = get(opr.input(0))
-                             .getType()
-                             .dyn_cast_or_null<mlir::MemRefType>()
-                             .getShape();
+        auto itype = get(opr.input(0))
+                             .getType()
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "currently only support MemRefType");
         auto res_type = mlir::MemRefType::get(
-                shape,
+                itype.getShape(),
                 megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
         return m_builder.create<dialect::TypeCvt>(
                 m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
                 opr.input(0)->dtype(), opr.param());
     }
+
+    mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) {
+        auto itype = get(opr.input(0))
+                             .getType()
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "the input type of Dimshuffle must be MemRefType");
+        auto ishape = itype.getShape();
+        auto param = opr.param();
+        std::vector<int32_t> pattern;
+        std::vector<int64_t> oshape;
+        for (size_t i = 0; i < param.pattern_len; i++) {
+            int32_t j = param.pattern[i];
+            pattern.push_back(j);
+            if (j < 0) {
+                oshape.push_back(1);
+            } else {
+                oshape.push_back(ishape[j]);
+            }
+        }
+        auto res_type = mlir::MemRefType::get(oshape, itype.getElementType());
+        return m_builder.create<dialect::Dimshuffle>(
+                m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
+                pattern);
+    }

     mlir::Type get_type(const TensorLayout& layout) {
         return layout_to_mlir_type(layout, m_builder);
     }
@@ -15,6 +15,7 @@
 #include "megbrain/jit/executor_opr.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/basic_arith_wrapper.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/test/helper.h"

 #include "megdnn/dtype.h"
@@ -539,6 +540,51 @@ add_typecvt_gtest(Uint8, Float32);

 #undef add_typecvt_gtest

+/* ===================== TestJITMlirDimshuffle ===================== */
+
+void run_dimshuffle(CompNode cn, TensorShape ishape,
+                    const std::vector<int>& pattern) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+    auto host_x = gen(ishape, cn);
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    auto y = opr::Dimshuffle::make(x, pattern);
+
+    auto ig_gen =
+            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->template same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+
+void run_dimshuffle_cases(CompNode cn) {
+    run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1});
+    run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2});
+}
+
+TEST(TestJITMlirDimshuffle, Basic) {
+    run_dimshuffle_cases(CompNode::load("cpu0"));
+}
+
+TEST(TestJITMlirDimshuffle, BasicGPU) {
+    REQUIRE_GPU(1);
+    run_dimshuffle_cases(CompNode::load("gpu0"));
+}
+
 #endif  // MGB_JIT_MLIR
 #endif  // MGB_JIT
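The two cases in run_dimshuffle_cases cover a pure 3-d transpose and a transpose combined with one broadcast axis. If broader coverage is wanted, identity and multi-broadcast patterns would also exercise the degenerate AffineMap and index-shuffle paths; a possible addition (suggestion only, not in the patch):

    run_dimshuffle(cn, {3, 4, 5}, {0, 1, 2});    // identity permutation
    run_dimshuffle(cn, {3, 4}, {-1, 0, 1, -1});  // broadcast axes on both ends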