@@ -294,22 +294,6 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
          cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
          cond_mlir_specific = true;
-#if MGB_JIT_MLIR
-    //! FIXME mlir does't support broadcast currently.
-    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
-    if (backend && !strcmp(backend, "MLIR")) {
-        for (VarNode* var : opr->input()) {
-            if (!SymbolVar{var}.as_immutable_scalar().valid()) {
-                if (opr->node_prop().dep_map().at(var) &
-                    DepType::DEV_VALUE) {
-                    if (!var->shape().eq_shape(opr->output(0)->shape())) {
-                        cond_mlir_specific = false;
-                    }
-                }
-            }
-        }
-    }
-#endif
     if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
         cond_mlir_specific) {
         ig_gen->add_opr(opr);
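Note: this hunk deletes the environment-variable-gated workaround that disabled MLIR fusion whenever a non-scalar device-value input's shape differed from the output shape. Broadcast is now handled inside the MLIR lowering itself, via the affine-map and multi-dimensional-index helpers added later in this patch, so the fusion pass no longer needs to reject such operators.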
@@ -57,23 +57,23 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
         }
     };
     for (const auto& arg : args.inputs) {
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
+        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
     int64_t nr_elements = 0;
     for (const auto& arg : args.outputs) {
         if (nr_elements == 0) {
-            nr_elements = arg.layout.total_nr_elems();
+            nr_elements = arg.from->layout().total_nr_elems();
         } else {
             mgb_assert(static_cast<size_t>(nr_elements) ==
                                arg.layout.total_nr_elems(),
                        "The number of elements of outputs mismatch, expected: "
                        "%zu got: %zu(%s)",
                        static_cast<size_t>(nr_elements),
-                       arg.layout.total_nr_elems(),
-                       arg.layout.to_string().c_str());
+                       arg.from->layout().total_nr_elems(),
+                       arg.from->layout().to_string().c_str());
         }
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
+        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
     const CompNodeEnv& env =
             CompNodeEnv::from_comp_node(fusion_opr->comp_node());
@@ -134,8 +134,8 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
     mgb_assert(fusion_opr->args().outputs.size() == 1,
                "Currently only support 1 outputs, got %zu",
                fusion_opr->args().outputs.size());
-    int out_dim = fusion_opr->args().outputs[0].layout.ndim;
-    DType dtype = fusion_opr->args().outputs[0].layout.dtype;
+    int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
+    DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
 #define cb_outdim(_ndim, _dtype)                                \
     if (_ndim == out_dim) {                                     \
         setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \
@@ -14,8 +14,10 @@
 #if MGB_JIT && MGB_JIT_MLIR

 #include "./common.h"
+#include "megbrain/jit/mlir/ir/utils.h"

 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include <mlir/Dialect/Affine/IR/AffineOps.h>

 using namespace mgb;
 using namespace jit;
@@ -28,9 +30,11 @@ cb(add, AddFOp);
 cb(sub, SubFOp);
 cb(mul, MulFOp);
 cb(div, DivFOp);
+cb(divI, SignedDivIOp);
 cb(mod, RemFOp);
 cb(bit_and, AndOp);
 cb(bit_or, OrOp);
+cb(modI, SignedRemIOp);
 #undef cb

 #define cb(name, mode)                                          \
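Note: SignedDivIOp/SignedRemIOp are the standard-dialect signed integer division/remainder ops; divI/modI are used later in this patch for index arithmetic when delinearizing a flat thread id. The binary cb macro is defined just above this hunk and is not shown in the diff; assuming it mirrors the unary form visible in the next hunk, cb(divI, SignedDivIOp) would expand roughly to:

    // Hypothetical expansion of cb(divI, SignedDivIOp); the real macro body
    // lives outside this diff.
    mlir::Value ValueBuilderHelper::divI(mlir::Value lhs, mlir::Value rhs) {
        return m_builder.create<mlir::SignedDivIOp>(m_location, lhs, rhs);
    }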
@@ -62,6 +66,11 @@ mlir::Value ValueBuilderHelper::const_val(float val) {
                                              m_builder.getF32FloatAttr(val));
 }

+mlir::Value ValueBuilderHelper::constI(int32_t val) {
+    return m_builder.create<mlir::ConstantOp>(m_location,
+                                              m_builder.getIndexAttr(val));
+}
+
 #define cb(name, op)                                            \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {     \
         return m_builder.create<mlir::op>(m_location, lhs);     \
@@ -97,6 +106,44 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                            false_val);
 }

+mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
+                                   const mlir::Value& val,
+                                   const megdnn::TensorLayout& layout) {
+    auto type = val.getType().cast<mlir::MemRefType>();
+    mgb_assert(type, "currently only support MemRefType");
+    std::vector<mlir::AffineExpr> exprs;
+    for (int i = 0; i < type.getRank(); ++i) {
+        if (layout[i] == 1) {
+            exprs.push_back(builder.getAffineConstantExpr(0));
+        } else {
+            exprs.push_back(builder.getAffineDimExpr(i));
+        }
+    }
+    auto map = mlir::AffineMap::get(type.getRank(), 0, exprs,
+                                    builder.getContext());
+    return map;
+}
+
+mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
+                                    const mlir::Location& loc,
+                                    const mlir::Value& val,
+                                    const mlir::ValueRange& index,
+                                    const megdnn::TensorLayout& dst) {
+    if (val.getType().isa<mlir::MemRefType>()) {
+        auto type = val.getType().cast<mlir::MemRefType>();
+        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
+        src_layout.init_contiguous_stride();
+        if (src_layout.eq_shape(dst)) {
+            return builder.create<mlir::AffineLoadOp>(loc, val, index);
+        } else {
+            auto lhs_map = get_affinemap(builder, val, src_layout);
+            return builder.create<mlir::AffineLoadOp>(loc, val, lhs_map,
+                                                      index);
+        }
+    } else {
+        return val;
+    }
+}
+
 #endif  // MGB_JIT && MGB_JIT_MLIR

 // vim: syntax=cpp.doxygen
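Note: get_affinemap is the core of the broadcast support on the CPU (affine) path: any axis whose extent is 1 in the source layout is mapped to the constant index 0, so the same element is re-read across that axis. A minimal sketch of the resulting access, assuming a builder, a location loc, loop induction variables loop_ivs, and a 23x1 memref lhs are in scope (all names illustrative, not from the patch):

    // For dst shape {23, 42} and src layout {23, 1}, the map is
    // (d0, d1) -> (d0, 0): iteration (i, j) loads lhs[i, 0], i.e. a broadcast.
    megdnn::TensorLayout src_layout({23, 1}, megdnn::dtype::Float32());
    mlir::AffineMap map = jit::get_affinemap(builder, lhs, src_layout);
    mlir::Value loaded =
            builder.create<mlir::AffineLoadOp>(loc, lhs, map, loop_ivs);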
@@ -14,7 +14,7 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR

+#include "megbrain/tensor.h"
 #include <mlir/Dialect/StandardOps/IR/Ops.h>
 #include <mlir/IR/OperationSupport.h>
 #include <mlir/IR/Value.h>
@@ -39,9 +39,11 @@ public:
     cb(sub);
     cb(mul);
     cb(div);
+    cb(divI);
     cb(max);
     cb(min);
     cb(mod);
+    cb(modI);
     cb(gt);
     cb(ge);
     cb(lt);
@@ -51,6 +53,7 @@ public:
     cb(bit_or);
 #undef cb

     mlir::Value const_val(float val);
+    mlir::Value constI(int32_t val);

 #define cb(name)                                                          \
     mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
@@ -89,6 +92,15 @@ mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
     }
 }

+mlir::AffineMap get_affinemap(mlir::OpBuilder& builder, const mlir::Value& val,
+                              const TensorLayout& layout);
+
+mlir::Value get_affine_load_op(mlir::OpBuilder& builder,
+                               const mlir::Location& loc,
+                               const mlir::Value& val,
+                               const mlir::ValueRange& index,
+                               const TensorLayout& dst);
+
 }  // namespace jit
 }  // namespace mgb
@@ -42,8 +42,8 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
     auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);

-    SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
-    SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
+    llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
+    llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
     buildAffineLoopNest(
             rewriter, loc, lower_bounds, memref_type.getShape(), steps,
             [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -96,17 +96,23 @@ struct BinaryOpLowering : public ConversionPattern {
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
+        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
+        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
+        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
+                [dst_layout, loc, this](OpBuilder& builder,
+                                        ValueRange memref_operands,
+                                        ValueRange loop_ivs) {
                     typename Op::Adaptor binary_adaptor(memref_operands);
                     LoweredOp lower_op;
-                    auto loaded_lhs = get_operand<AffineLoadOp>(
-                            builder, loc, binary_adaptor.lhs(), loop_ivs);
-                    auto loaded_rhs = get_operand<AffineLoadOp>(
-                            builder, loc, binary_adaptor.rhs(), loop_ivs);
+                    auto loaded_lhs = get_affine_load_op(builder, loc,
+                                                         binary_adaptor.lhs(),
+                                                         loop_ivs, dst_layout);
+                    auto loaded_rhs = get_affine_load_op(builder, loc,
+                                                         binary_adaptor.rhs(),
+                                                         loop_ivs, dst_layout);
                     return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
                 });
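Note: on the CPU path the loop nest is built over the output memref's shape, so the induction variables always index the destination; get_affine_load_op compares each operand's (contiguous) layout against dst_layout and only inserts the broadcast map when the shapes differ, leaving same-shape operands on the plain identity-map load. The ternary lowering below applies the identical pattern to x/y/z.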
@@ -128,19 +134,26 @@ struct TernaryOpLowering : public ConversionPattern {
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
+        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
+        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
+        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
+                [dst_layout, loc](OpBuilder& builder,
+                                  ValueRange memref_operands,
+                                  ValueRange loop_ivs) {
                     typename Op::Adaptor ternary_adaptor(memref_operands);
                     LoweredOp lower_op;
-                    auto loaded_x = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.x(), loop_ivs);
-                    auto loaded_y = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.y(), loop_ivs);
-                    auto loaded_z = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.z(), loop_ivs);
+                    auto loaded_x = get_affine_load_op(builder, loc,
+                                                       ternary_adaptor.x(),
+                                                       loop_ivs, dst_layout);
+                    auto loaded_y = get_affine_load_op(builder, loc,
+                                                       ternary_adaptor.y(),
+                                                       loop_ivs, dst_layout);
+                    auto loaded_z = get_affine_load_op(builder, loc,
+                                                       ternary_adaptor.z(),
+                                                       loop_ivs, dst_layout);
                     return lower_op(builder, loc,
                                     {loaded_x, loaded_y, loaded_z});
@@ -166,8 +179,8 @@ struct AssignOpLowering : public ConversionPattern {
         auto memref_type = operands[0].getType().cast<MemRefType>();
         AssignOpAdaptor assign_adaptor(operands);

-        SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
-        SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
+        llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
+        llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
         buildAffineLoopNest(
                 rewriter, loc, lower_bounds, memref_type.getShape(), steps,
                 [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -52,6 +52,54 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
     return index;
 }

+megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
+    auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
+    mgb_assert(func_op, "Unexpected launch op.");
+    for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
+         block_iter++) {
+        for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
+             op_iter++) {
+            auto op = llvm::dyn_cast_or_null<AssignOp>(&(*op_iter));
+            if (op && op.getNumOperands() > 0) {
+                return mlir_type_to_layout(*(op.operand_type_begin()));
+            }
+        }
+    }
+    mgb_throw(MegBrainError, "Unexpected launch op.");
+}
+
+std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
+                                          const Location& loc,
+                                          const mlir::Value& val,
+                                          const megdnn::TensorLayout& dst) {
+    Value index = get_tid(rewriter, loc);
+
+    auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
+    if (type) {
+        ValueBuilderHelper helper(rewriter, loc);
+        std::vector<mlir::Value> idxs;
+        idxs.resize(dst.ndim);
+        mlir::Value dim_index = index;
+        for (int i = dst.ndim - 1; i >= 0; i--) {
+            auto cur_index = helper.modI(dim_index, helper.constI(dst[i]));
+            idxs[i] = cur_index;
+            dim_index = helper.divI(dim_index, helper.constI(dst[i]));
+        }
+
+        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
+        src_layout.init_contiguous_stride();
+        for (int i = 0; i < type.getRank(); ++i) {
+            if (src_layout[i] == 1) {
+                idxs[i] = helper.constI(0);
+            }
+        }
+        return idxs;
+    } else {
+        return {index};
+    }
+}
+
 template <typename Op, typename LoweredOp>
 struct UnaryOpLowering : public ConversionPattern {
     UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
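Note: on the GPU path there is no affine loop nest, only a flat thread id. output_layout recovers the destination layout by scanning the kernel function backwards for the final AssignOp, and get_multidim_tid then delinearizes the thread id into row-major coordinates of that output shape with repeated modI/divI, pinning coordinates on size-1 source axes to 0. A plain-C++ sketch of the same index math (illustrative only, not part of the patch):

    #include <cstddef>
    #include <vector>

    // tid: flat thread id; dst: output shape; src: operand shape (same rank,
    // broadcast axes have extent 1). Returns per-axis load coordinates.
    std::vector<size_t> delinearize(size_t tid, const std::vector<size_t>& dst,
                                    const std::vector<size_t>& src) {
        std::vector<size_t> idx(dst.size());
        for (size_t n = dst.size(); n-- > 0;) {
            idx[n] = tid % dst[n];  // coordinate along axis n
            tid /= dst[n];
        }
        for (size_t n = 0; n < src.size(); ++n) {
            if (src[n] == 1) {
                idx[n] = 0;  // broadcast axis: always read element 0
            }
        }
        return idx;
    }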
@@ -66,7 +114,9 @@ struct UnaryOpLowering : public ConversionPattern {
         typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

-        auto index = get_tid(rewriter, loc);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
+                                      dst_layout);
         auto loaded_lhs =
                 get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
@@ -99,11 +149,15 @@ struct BinaryOpLowering : public ConversionPattern {
         typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

-        auto index = get_tid(rewriter, loc);
-        auto loaded_lhs =
-                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
-        auto loaded_rhs =
-                get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
+        auto dst_layout = output_layout(m_launch_op);
+        auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
+                                          dst_layout);
+        auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(),
+                                          dst_layout);
+        auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
+                                              binary_adaptor.lhs(), lhs_index);
+        auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
+                                              binary_adaptor.rhs(), rhs_index);

         LoweredOp lower_op;
@@ -135,13 +189,19 @@ struct TernaryOpLowering : public ConversionPattern {
         typename Op::Adaptor ternary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

-        auto index = get_tid(rewriter, loc);
-        auto loaded_x =
-                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
-        auto loaded_y =
-                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
-        auto loaded_z =
-                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(),
+                                        dst_layout);
+        auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
+                                        dst_layout);
+        auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
+                                        dst_layout);
+        auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
+                                            index_x);
+        auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
+                                            index_y);
+        auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
+                                            index_z);

         LoweredOp lower_op;
@@ -242,7 +302,9 @@ struct AssignOpLowering : public ConversionPattern {
         AssignOpAdaptor assign_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

-        auto index = get_tid(rewriter, loc);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
+                                      dst_layout);
         auto loaded_lhs =
                 get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
@@ -98,7 +98,6 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
     for (size_t i = 0; i < layout.ndim; i++) {
         shape.push_back(layout[i]);
     }
-
     switch (layout.dtype.enumv()) {
         case megdnn::DTypeEnum::Float32:
             return mlir::MemRefType::get(shape, builder.getF32Type());
@@ -73,10 +73,10 @@ private:
                                       m_symbol_table);
         std::vector<mlir::Type> func_args;
         for (auto&& arg : args.inputs) {
-            func_args.push_back(get_type(arg.layout));
+            func_args.push_back(get_type(arg.from->layout()));
         }
         for (auto&& arg : args.outputs) {
-            func_args.push_back(get_type(arg.layout));
+            func_args.push_back(get_type(arg.from->layout()));
         }
         //! the last arg is nr_elements
         func_args.push_back(m_builder.getIndexType());
@@ -44,7 +44,6 @@ megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);
 megdnn::DType mlir_type_to_dtype(mlir::Type type);
 mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
                                      mlir::Builder& builder);
-
 }  // namespace jit
 }  // namespace mgb
@@ -130,8 +130,8 @@ void run_mlir(CompNode cn) {
     auto graph = ComputingGraph::make();
     HostTensorGenerator<dtype::Float32> gen;

-    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
-         host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn);
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
+         host_x2 = gen({23, 42}, cn);

     auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
          b = opr::Host2DeviceCopy::make(*graph, host_x1),
@@ -159,6 +159,43 @@ void run_mlir(CompNode cn) {
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
 }

+void run_mlir_broadcast(CompNode cn) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<dtype::Float32> gen;
+
+    auto host_x0 = gen({10, 20, 5, 6}, cn), host_x1 = gen({1, 20, 1, 1}, cn),
+         host_x2 = gen({10, 1, 5, 1}, cn), host_x3 = gen({10, 1, 1, 1}, cn);
+
+    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
+         b = opr::Host2DeviceCopy::make(*graph, host_x1),
+         c = opr::Host2DeviceCopy::make(*graph, host_x2),
+         d = opr::Host2DeviceCopy::make(*graph, host_x3);
+
+    auto y =
+            opr::Elemwise::make({a, b, c}, opr::Elemwise::Mode::FUSE_MUL_ADD3) +
+            opr::Elemwise::make({d}, opr::Elemwise::Mode::ABS) - 0.3f;
+
+    auto ig_gen =
+            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+
 struct MlirTestOpt {
     float low;
     float high;
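Note: the new test exercises broadcast along different axis combinations against the {10, 20, 5, 6} operand: b broadcasts on axes 0/2/3, c on axes 1/3, and d on axes 1/2/3. It routes FUSE_MUL_ADD3 (a * b + c), a unary ABS, and scalar arithmetic through one fused kernel, and checks the JIT result against the unfused graph.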
@@ -252,12 +289,14 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
 TEST(TestJITMlirCodeGen, Basic) {
     auto cn = CompNode::load("cpu0");
     run_mlir(cn);
+    run_mlir_broadcast(cn);
 }

 TEST(TestJITMlirCodeGen, BasicGPU) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");
     run_mlir(cn);
+    run_mlir_broadcast(cn);
 }

 ///////////////////////// unary ///////////////////////////////
@@ -1580,8 +1580,8 @@ void run_mlir(CompNode cn) {
     JITExecutor* jit;
     unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
-    ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
-    ASSERT_EQ(3u, jit->input().size());
+    ASSERT_EQ(0u, find_oprs<opr::Elemwise>(*funcs.second).size());
+    ASSERT_EQ(5u, jit->input().size());
 }

 TEST(TestJITExecutor, TestJITMlirFusion) {
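Note: the tightened assertions reflect that, with broadcast supported, the MLIR fusion pass now absorbs every Elemwise operator in this fusion test into the JITExecutor (0 left standing instead of 2), which consequently takes 5 inputs instead of 3.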