@@ -294,22 +294,6 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
cond_mlir_specific = true;
#if MGB_JIT_MLIR
//! FIXME mlir does't support broadcast currently.
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
if (backend && !strcmp(backend, "MLIR")) {
for (VarNode* var : opr->input()) {
if (!SymbolVar{var}.as_immutable_scalar().valid()) {
if (opr->node_prop().dep_map().at(var) &
DepType::DEV_VALUE) {
if (!var->shape().eq_shape(opr->output(0)->shape())) {
cond_mlir_specific = false;
}
}
}
}
}
#endif
if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
cond_mlir_specific) {
ig_gen->add_opr(opr);
@@ -57,23 +57,23 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
}
};
for (const auto& arg : args.inputs) {
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}
int64_t nr_elements = 0;
for (const auto& arg : args.outputs) {
if (nr_elements == 0) {
nr_elements = arg.layout.total_nr_elems();
nr_elements = arg.from->layout().total_nr_elems();
} else {
mgb_assert(static_cast<size_t>(nr_elements) ==
arg.layout.total_nr_elems(),
"The number of elements of outputs mismatch, expected: "
"%zu got: %zu(%s)",
static_cast<size_t>(nr_elements),
arg.layout.total_nr_elems(),
arg.layout.to_string().c_str());
arg.from->layout().total_nr_elems(),
arg.from->layout().to_string().c_str());
}
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(fusion_opr->comp_node());
@@ -134,8 +134,8 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
mgb_assert(fusion_opr->args().outputs.size() == 1,
"Currently only support 1 outputs, got %zu",
fusion_opr->args().outputs.size());
int out_dim = fusion_opr->args().outputs[0].layout.ndim;
DType dtype = fusion_opr->args().outputs[0].layout.dtype;
int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
#define cb_outdim(_ndim, _dtype) \
if (_ndim == out_dim) { \
setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \
@@ -14,8 +14,10 @@
#if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include <mlir/Dialect/Affine/IR/AffineOps.h>
using namespace mgb;
using namespace jit;
@@ -28,9 +30,11 @@ cb(add, AddFOp);
cb(sub, SubFOp);
cb(mul, MulFOp);
cb(div, DivFOp);
cb(divI, SignedDivIOp);
cb(mod, RemFOp);
cb(bit_and, AndOp);
cb(bit_or, OrOp);
cb(modI, SignedRemIOp);
#undef cb
#define cb(name, mode) \
@@ -62,6 +66,11 @@ mlir::Value ValueBuilderHelper::const_val(float val) {
m_builder.getF32FloatAttr(val));
}
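//! integer constant of MLIR index type; used for the per-dimension index
//! arithmetic in the GPU kernel lowering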
mlir::Value ValueBuilderHelper::constI(int32_t val) {
return m_builder.create<mlir::ConstantOp>(m_location,
m_builder.getIndexAttr(val));
}
#define cb(name, op) \
mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \
return m_builder.create<mlir::op>(m_location, lhs); \
@@ -97,6 +106,44 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
false_val);
}
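//! build an affine map that handles broadcast on load: dimensions of size 1
//! in \p layout are mapped to the constant index 0, all other dimensions map
//! to the corresponding loop index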
mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
const mlir::Value& val,
const megdnn::TensorLayout& layout) {
auto type = val.getType().cast<mlir::MemRefType>();
mgb_assert(type, "currently only support MemRefType");
std::vector<mlir::AffineExpr> exprs;
for (int i = 0; i < type.getRank(); ++i) {
if (layout[i] == 1) {
exprs.push_back(builder.getAffineConstantExpr(0));
} else {
exprs.push_back(builder.getAffineDimExpr(i));
}
}
auto map = mlir::AffineMap::get(type.getRank(), 0, exprs,
builder.getContext());
return map;
}
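//! emit an affine load of \p val at \p index: if the source shape already
//! equals \p dst a plain AffineLoadOp is created, otherwise the load goes
//! through the broadcast-aware map from get_affinemap(); non-memref values
//! (scalars) are returned unchanged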
mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
const mlir::Location& loc,
const mlir::Value& val,
const mlir::ValueRange& index,
const megdnn::TensorLayout& dst) {
if (val.getType().isa<mlir::MemRefType>()) {
auto type = val.getType().cast<mlir::MemRefType>();
megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride();
if (src_layout.eq_shape(dst)) {
return builder.create<mlir::AffineLoadOp>(loc, val, index);
} else {
auto lhs_map = get_affinemap(builder, val, src_layout);
return builder.create<mlir::AffineLoadOp>(loc, val, lhs_map, index);
}
} else {
return val;
}
}
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
@@ -14,7 +14,7 @@
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/tensor.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/Value.h>
@@ -39,9 +39,11 @@ public:
cb(sub);
cb(mul);
cb(div);
cb(divI);
cb(max);
cb(min);
cb(mod);
cb(modI);
cb(gt);
cb(ge);
cb(lt);
@@ -51,6 +53,7 @@ public:
cb(bit_or);
#undef cb
mlir::Value const_val(float val);
mlir::Value constI(int32_t val);
#define cb(name) \
mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
@@ -89,6 +92,15 @@ mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
}
}
mlir::AffineMap get_affinemap(mlir::OpBuilder& builder, const mlir::Value& val,
const TensorLayout& layout);
mlir::Value get_affine_load_op(mlir::OpBuilder& builder,
const mlir::Location& loc,
const mlir::Value& val,
const mlir::ValueRange& index,
const TensorLayout& dst);
} // namespace jit
} // namespace mgb
@@ -42,8 +42,8 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);
SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
buildAffineLoopNest(
rewriter, loc, lower_bounds, memref_type.getShape(), steps,
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -96,17 +96,23 @@ struct BinaryOpLowering : public ConversionPattern {
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
dst_layout.init_contiguous_stride();
lower_op_to_loops(
op, operands, rewriter,
[loc](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
[dst_layout, loc, this](OpBuilder& builder,
ValueRange memref_operands,
ValueRange loop_ivs) {
typename Op::Adaptor binary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_lhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.lhs(), loop_ivs);
auto loaded_rhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.rhs(), loop_ivs);
auto loaded_lhs = get_affine_load_op(builder, loc,
binary_adaptor.lhs(),
loop_ivs, dst_layout);
auto loaded_rhs = get_affine_load_op(builder, loc,
binary_adaptor.rhs(),
loop_ivs, dst_layout);
return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
});
@@ -128,19 +134,26 @@ struct TernaryOpLowering : public ConversionPattern {
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
dst_layout.init_contiguous_stride();
lower_op_to_loops(
op, operands, rewriter,
[loc](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
[dst_layout, loc](OpBuilder& builder,
ValueRange memref_operands,
ValueRange loop_ivs) {
typename Op::Adaptor ternary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_x = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.x(), loop_ivs);
auto loaded_y = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.y(), loop_ivs);
auto loaded_z = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.z(), loop_ivs);
auto loaded_x = get_affine_load_op(builder, loc,
ternary_adaptor.x(),
loop_ivs, dst_layout);
auto loaded_y = get_affine_load_op(builder, loc,
ternary_adaptor.y(),
loop_ivs, dst_layout);
auto loaded_z = get_affine_load_op(builder, loc,
ternary_adaptor.z(),
loop_ivs, dst_layout);
return lower_op(builder, loc,
{loaded_x, loaded_y, loaded_z});
@@ -166,8 +179,8 @@ struct AssignOpLowering : public ConversionPattern {
auto memref_type = operands[0].getType().cast<MemRefType>();
AssignOpAdaptor assign_adaptor(operands);
SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
buildAffineLoopNest(
rewriter, loc, lower_bounds, memref_type.getShape(), steps,
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -52,6 +52,54 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
return index;
}
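//! recover the output layout of the fused kernel by scanning the enclosing
//! FuncOp backwards for the last AssignOp and converting the type of its
//! first operand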
megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
mgb_assert(func_op, "Unexpected launch op.");
for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
block_iter++) {
for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
op_iter++) {
auto op = llvm::dyn_cast_or_null<AssignOp>(&(*op_iter));
if (op && op.getNumOperands() > 0) {
return mlir_type_to_layout(*(op.operand_type_begin()));
}
}
}
mgb_throw(MegBrainError, "Unexpected launch op.");
}
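//! decompose the linear thread id into one index per output dimension of
//! \p dst (innermost dimension varies fastest); indices of size-1 (broadcast)
//! dimensions of \p val are forced to 0, and non-memref values fall back to
//! the plain linear index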
std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
const Location& loc,
const mlir::Value& val,
const megdnn::TensorLayout& dst) {
Value index = get_tid(rewriter, loc);
auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
if (type) {
ValueBuilderHelper helper(rewriter, loc);
std::vector<mlir::Value> idxs;
idxs.resize(dst.ndim);
mlir::Value dim_index = index;
for (int i = dst.ndim - 1; i >= 0; i--) {
auto cur_index = helper.modI(dim_index, helper.constI(dst[i]));
idxs[i] = cur_index;
dim_index = helper.divI(dim_index, helper.constI(dst[i]));
}
megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride();
for (int i = 0; i < type.getRank(); ++i) {
if (src_layout[i] == 1) {
idxs[i] = helper.constI(0);
}
}
return idxs;
} else {
return {index};
}
}
template <typename Op, typename LoweredOp>
struct UnaryOpLowering : public ConversionPattern {
UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
@@ -66,7 +114,9 @@ struct UnaryOpLowering : public ConversionPattern {
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
dst_layout);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
@@ -99,11 +149,15 @@ struct BinaryOpLowering : public ConversionPattern {
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
auto loaded_rhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
auto dst_layout = output_layout(m_launch_op);
auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
dst_layout);
auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(),
dst_layout);
auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
binary_adaptor.lhs(), lhs_index);
auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
binary_adaptor.rhs(), rhs_index);
LoweredOp lower_op;
@@ -135,13 +189,19 @@ struct TernaryOpLowering : public ConversionPattern {
typename Op::Adaptor ternary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_x =
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
auto loaded_y =
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
auto loaded_z =
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
auto dst_layout = output_layout(m_launch_op);
auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(),
dst_layout);
auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
dst_layout);
auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
dst_layout);
auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
index_x);
auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
index_y);
auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
index_z);
LoweredOp lower_op;
@@ -242,7 +302,9 @@ struct AssignOpLowering : public ConversionPattern {
AssignOpAdaptor assign_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
dst_layout);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
@@ -98,7 +98,6 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
for (size_t i = 0; i < layout.ndim; i++) {
shape.push_back(layout[i]);
}
switch (layout.dtype.enumv()) {
case megdnn::DTypeEnum::Float32:
return mlir::MemRefType::get(shape, builder.getF32Type());
@@ -73,10 +73,10 @@ private:
m_symbol_table);
std::vector<mlir::Type> func_args;
for (auto&& arg : args.inputs) {
func_args.push_back(get_type(arg.layout));
func_args.push_back(get_type(arg.from->layout()));
}
for (auto&& arg : args.outputs) {
func_args.push_back(get_type(arg.layout));
func_args.push_back(get_type(arg.from->layout()));
}
//! the last arg is nr_elements
func_args.push_back(m_builder.getIndexType());
@@ -44,7 +44,6 @@ megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);
megdnn::DType mlir_type_to_dtype(mlir::Type type);
mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
mlir::Builder& builder);
} // namespace jit
} // namespace mgb
@@ -130,8 +130,8 @@ void run_mlir(CompNode cn) {
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;
auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn);
auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
host_x2 = gen({23, 42}, cn);
auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
b = opr::Host2DeviceCopy::make(*graph, host_x1),
@@ -159,6 +159,43 @@ void run_mlir(CompNode cn) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
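//! check elementwise fusion on the MLIR backend when inputs must be
//! broadcast along one or more dimensions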
void run_mlir_broadcast(CompNode cn) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;
auto host_x0 = gen({10, 20, 5, 6}, cn), host_x1 = gen({1, 20, 1, 1}, cn),
host_x2 = gen({10, 1, 5, 1}, cn), host_x3 = gen({10, 1, 1, 1}, cn);
auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
b = opr::Host2DeviceCopy::make(*graph, host_x1),
c = opr::Host2DeviceCopy::make(*graph, host_x2),
d = opr::Host2DeviceCopy::make(*graph, host_x3);
auto y =
opr::Elemwise::make({a, b, c}, opr::Elemwise::Mode::FUSE_MUL_ADD3) +
opr::Elemwise::make({d}, opr::Elemwise::Mode::ABS) - 0.3f;
auto ig_gen =
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}
auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
struct MlirTestOpt {
float low;
float high;
@@ -252,12 +289,14 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
TEST(TestJITMlirCodeGen, Basic) {
auto cn = CompNode::load("cpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
}
TEST(TestJITMlirCodeGen, BasicGPU) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
}
///////////////////////// unary ///////////////////////////////
@@ -1580,8 +1580,8 @@ void run_mlir(CompNode cn) {
JITExecutor* jit;
unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
ASSERT_EQ(3u, jit->input().size());
ASSERT_EQ(0u, find_oprs<opr::Elemwise>(*funcs.second).size());
ASSERT_EQ(5u, jit->input().size());
}
TEST(TestJITExecutor, TestJITMlirFusion) {