GitOrigin-RevId: 27b1649c04
tags/v1.0.0-rc1
@@ -101,9 +101,9 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
                compiler = std::make_unique<CudaCompiler>();
                break;
            }
#endif
            mgb_throw(InternalError, "No compiler support for cuda");
            break;
#endif
        case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
            if (!backend || !strcmp(backend, "MLIR")) {
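For context, a minimal self-contained sketch of the environment-driven backend selection that Compiler::get performs (hypothetical stand-in types and error handling, not the real MegBrain classes):

    #include <cstdlib>
    #include <cstring>
    #include <memory>
    #include <stdexcept>

    // Hypothetical stand-ins for the real compiler classes.
    struct Compiler { virtual ~Compiler() = default; };
    struct CudaCompiler : Compiler {};
    struct MLIRCompiler : Compiler {};

    std::unique_ptr<Compiler> pick_compiler(bool on_cuda) {
        // MGB_JIT_BACKEND unset falls through to a default backend.
        const char* backend = std::getenv("MGB_JIT_BACKEND");
        if (on_cuda && (!backend || !std::strcmp(backend, "CUDA")))
            return std::make_unique<CudaCompiler>();
        if (!backend || !std::strcmp(backend, "MLIR"))
            return std::make_unique<MLIRCompiler>();
        // the real code raises InternalError here
        throw std::runtime_error("no compiler support for this device");
    }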
@@ -20,6 +20,10 @@
#if MGB_JIT
#if MGB_JIT_MLIR
#include "./mlir/ir/each_mode.h"
#endif
using namespace mgb;
using namespace gopt;
using namespace jit;
@@ -339,35 +343,76 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
        return false;
    }
    //! The MLIR backend has some constraints
    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
    // float elemwise
    if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
        return ast_c::check_elem_mode(elem->param().mode) &&
        bool ret = true;
#if MGB_JIT_MLIR
        if (!strcmp(backend, "MLIR")) {
            switch (elem->param().mode) {
#define cb(_, _mode)                 \
    case opr::Elemwise::Mode::_mode: \
        ret = true;                  \
        break;
                MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                default:
                    ret = false;
#undef cb
            }
#define FOREACH_ELEMWISE_SKIP_MODE(cb) cb(SIN)
            //! FIXME: MLIR on CUDA doesn't support sin currently.
            if (opr->output(0)->comp_node().device_type() ==
                CompNode::DeviceType::CUDA) {
                switch (elem->param().mode) {
#define cb(_mode)                    \
    case opr::Elemwise::Mode::_mode: \
        ret = false;                 \
        break;
                    FOREACH_ELEMWISE_SKIP_MODE(cb)
                    default:
                        break;
#undef cb
                }
            }
#undef FOREACH_ELEMWISE_SKIP_MODE
        }
#endif // MGB_JIT_MLIR
        return ret && ast_c::check_elem_mode(elem->param().mode) &&
               elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
    }
    if (opr->same_type<opr::PowC>()) {
        return true;
    }
    if (strcmp(backend, "MLIR")) {
        if (opr->same_type<opr::PowC>()) {
            return true;
        }
        // float typecvt (e.g. used in f16 training)
        if (opr->same_type<opr::TypeCvt>()) {
            auto category = opr->input(0)->dtype().category();
            if (category != opr->output(0)->dtype().category())
                return false;
            return category == DTypeCategory::FLOAT;
        }
        // float typecvt (e.g. used in f16 training)
        if (opr->same_type<opr::TypeCvt>()) {
            auto category = opr->input(0)->dtype().category();
            if (category != opr->output(0)->dtype().category())
                return false;
            return category == DTypeCategory::FLOAT;
        }
        // float reduce
        if ((m_feature_bits & JITFeatureBits::REDUCE) &&
            opr->same_type<opr::Reduce>()) {
            return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
        }
        // float reduce
        if ((m_feature_bits & JITFeatureBits::REDUCE) &&
            opr->same_type<opr::Reduce>()) {
            return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
        }
        // dimshuffle
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
            opr->same_type<opr::Dimshuffle>()) {
            auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
            return param.pattern_len <= 4;
        // dimshuffle
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
            opr->same_type<opr::Dimshuffle>()) {
            auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
            return param.pattern_len <= 4;
        }
    }
    // existing JITExecutor
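The whitelist above is driven by X-macros: each MLIR_MGB_FOREACH_ELEMWISE_MODE_* list expands a callback once per supported mode, and the default branch rejects everything else. A standalone sketch of the same technique (illustrative Mode values, not the real opr::Elemwise modes):

    // X-macro whitelist: FOREACH_SUPPORTED_MODE enumerates the accepted
    // modes; anything outside the list hits the default branch.
    enum class Mode { ADD, MUL, SIN, MAX };

    #define FOREACH_SUPPORTED_MODE(cb) cb(ADD) cb(MUL) cb(MAX)

    bool mode_supported(Mode m) {
        switch (m) {
    #define cb(_mode)     \
        case Mode::_mode: \
            return true;
            FOREACH_SUPPORTED_MODE(cb)
    #undef cb
            default:
                return false;  // e.g. SIN, which the CUDA path also skips
        }
    }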
@@ -10,7 +10,6 @@
 * implied.
 */
#include "llvm/Pass.h"
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
@@ -40,6 +39,7 @@
#include <llvm/Support/TargetSelect.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>
#include <llvm/Pass.h>
#include <dlfcn.h>
#include <dirent.h>
@@ -77,6 +77,16 @@ private:
    mlir::Location m_location;
};
template <typename Op>
mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
                        const mlir::Value& val, const mlir::ValueRange& index) {
    if (val.getType().isa<mlir::MemRefType>()) {
        return builder.create<Op>(loc, val, index);
    } else {
        return val;
    }
}
}  // namespace jit
}  // namespace mgb
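The new get_operand<Op> helper lets the lowering patterns below treat memref-backed operands and already-materialized scalars (such as lowered mgb.sconst results) uniformly. A hedged usage fragment, assuming the 2020-era MLIR API this file targets:

    // inside a lowering callback: a MemRefType operand is materialized
    // with an affine load, while a scalar SSA value passes through as-is
    mlir::Value lhs = get_operand<mlir::AffineLoadOp>(
            builder, loc, adaptor.lhs(), loop_ivs);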
@@ -14,6 +14,7 @@
#if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/dialect.h"
#include "./types.h"
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpImplementation.h>
@@ -74,8 +74,8 @@ struct UnaryOpLowering : public ConversionPattern {
            typename Op::Adaptor binary_adaptor(memref_operands);
            LoweredOp lower_op;
            auto loaded_lhs = builder.create<AffineLoadOp>(
                    loc, binary_adaptor.lhs(), loop_ivs);
            auto loaded_lhs = get_operand<AffineLoadOp>(
                    builder, loc, binary_adaptor.lhs(), loop_ivs);
            return lower_op(builder, loc, {loaded_lhs});
        });
@@ -104,10 +104,10 @@ struct BinaryOpLowering : public ConversionPattern {
            typename Op::Adaptor binary_adaptor(memref_operands);
            LoweredOp lower_op;
            auto loaded_lhs = builder.create<AffineLoadOp>(
                    loc, binary_adaptor.lhs(), loop_ivs);
            auto loaded_rhs = builder.create<AffineLoadOp>(
                    loc, binary_adaptor.rhs(), loop_ivs);
            auto loaded_lhs = get_operand<AffineLoadOp>(
                    builder, loc, binary_adaptor.lhs(), loop_ivs);
            auto loaded_rhs = get_operand<AffineLoadOp>(
                    builder, loc, binary_adaptor.rhs(), loop_ivs);
            return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
        });
@@ -136,12 +136,12 @@ struct TernaryOpLowering : public ConversionPattern {
            typename Op::Adaptor ternary_adaptor(memref_operands);
            LoweredOp lower_op;
            auto loaded_x = builder.create<AffineLoadOp>(
                    loc, ternary_adaptor.x(), loop_ivs);
            auto loaded_y = builder.create<AffineLoadOp>(
                    loc, ternary_adaptor.y(), loop_ivs);
            auto loaded_z = builder.create<AffineLoadOp>(
                    loc, ternary_adaptor.z(), loop_ivs);
            auto loaded_x = get_operand<AffineLoadOp>(
                    builder, loc, ternary_adaptor.x(), loop_ivs);
            auto loaded_y = get_operand<AffineLoadOp>(
                    builder, loc, ternary_adaptor.y(), loop_ivs);
            auto loaded_z = get_operand<AffineLoadOp>(
                    builder, loc, ternary_adaptor.z(), loop_ivs);
            return lower_op(builder, loc,
                            {loaded_x, loaded_y, loaded_z});
@@ -193,6 +193,19 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
    }
};
struct ConstantScalarOpLowering
        : public OpRewritePattern<jit::ConstantScalarOp> {
    using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
                                  PatternRewriter& rewriter) const final {
        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
                op, constant_scalar_adaptor.value());
        return success();
    }
};
class MgbToAffineLoweringPass
        : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
@@ -207,7 +220,8 @@ public:
                        cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                ReturnOpLowering,
                AssignOpLowering>(&getContext());
                AssignOpLowering, ConstantScalarOpLowering>(
                &getContext());
#undef cb
        if (failed(applyPartialConversion(getFunction(), target, patterns))) {
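The effect of ConstantScalarOpLowering, written out as an illustrative (hand-written, not tool-generated) IR transition:

    // before lowering: the dialect-level scalar constant
    //   %0 = "mgb.sconst"() {value = 3.000000e-01 : f32} : () -> f32
    // after lowering: the standard-dialect constant that replaces it
    //   %0 = constant 3.000000e-01 : f32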
@@ -38,16 +38,6 @@ using namespace jit;
namespace {
mlir::Value get_operand(ConversionPatternRewriter& rewriter,
                        const mlir::Location& loc, const mlir::Value& val,
                        const mlir::Value& index) {
    if (val.getType().isa<mlir::MemRefType>()) {
        return rewriter.create<LoadOp>(loc, val, index);
    } else {
        return val;
    }
}
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
    auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
@@ -64,7 +54,7 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
template <typename Op, typename LoweredOp>
struct UnaryOpLowering : public ConversionPattern {
    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(Op::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}
@@ -74,11 +64,11 @@ struct UnaryOpLowering : public ConversionPattern {
        auto loc = op->getLoc();
        typename Op::Adaptor binary_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto index = get_tid(rewriter, loc);
        auto loaded_lhs =
                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
        LoweredOp lower_op;
@@ -87,7 +77,7 @@ struct UnaryOpLowering : public ConversionPattern {
    }
private:
    gpu::LaunchOp* m_launch_op;
    gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
@@ -97,7 +87,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
template <typename Op, typename LoweredOp>
struct BinaryOpLowering : public ConversionPattern {
    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(Op::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}
@@ -107,13 +97,13 @@ struct BinaryOpLowering : public ConversionPattern {
        auto loc = op->getLoc();
        typename Op::Adaptor binary_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto index = get_tid(rewriter, loc);
        auto loaded_lhs =
                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
        auto loaded_rhs =
                get_operand(rewriter, loc, binary_adaptor.rhs(), index);
                get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
        LoweredOp lower_op;
@@ -123,7 +113,7 @@ struct BinaryOpLowering : public ConversionPattern {
    }
private:
    gpu::LaunchOp* m_launch_op;
    gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
@@ -133,7 +123,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
template <typename Op, typename LoweredOp>
struct TernaryOpLowering : public ConversionPattern {
    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(Op::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}
@@ -143,15 +133,15 @@ struct TernaryOpLowering : public ConversionPattern {
        auto loc = op->getLoc();
        typename Op::Adaptor ternary_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto index = get_tid(rewriter, loc);
        auto loaded_x =
                get_operand(rewriter, loc, ternary_adaptor.x(), index);
                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
        auto loaded_y =
                get_operand(rewriter, loc, ternary_adaptor.y(), index);
                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
        auto loaded_z =
                get_operand(rewriter, loc, ternary_adaptor.z(), index);
                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
        LoweredOp lower_op;
@@ -161,7 +151,7 @@ struct TernaryOpLowering : public ConversionPattern {
    }
private:
    gpu::LaunchOp* m_launch_op;
    gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
@@ -171,7 +161,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
#undef cb
struct ReturnOpLowering : public ConversionPattern {
    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}
@@ -182,10 +172,10 @@ struct ReturnOpLowering : public ConversionPattern {
        auto loc = op->getLoc();
        //! remove the first gpu.terminator
        m_launch_op->body().front().front().erase();
        m_launch_op.body().front().front().erase();
        //! if (tid >= nr_tid) { return; } at the beginning of the block
        rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
        rewriter.setInsertionPointToStart(&(m_launch_op.body().front()));
        Block* cond_block = rewriter.getInsertionBlock();
        Block::iterator op_position = rewriter.getInsertionPoint();
        Block* remaining_ops_block =
@@ -195,7 +185,7 @@ struct ReturnOpLowering : public ConversionPattern {
        auto index = get_tid(rewriter, loc);
        auto comparison = rewriter.create<mlir::CmpIOp>(
                loc, CmpIPredicate::sge, index,
                m_launch_op->getParentOfType<mlir::FuncOp>()
                m_launch_op.getParentOfType<mlir::FuncOp>()
                        .getArguments()
                        .back());
@@ -216,11 +206,31 @@ struct ReturnOpLowering : public ConversionPattern {
    }
private:
    gpu::LaunchOp* m_launch_op;
    gpu::LaunchOp& m_launch_op;
};
struct ConstantScalarOpLowering
        : public OpRewritePattern<jit::ConstantScalarOp> {
    ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : OpRewritePattern<jit::ConstantScalarOp>(ctx),
              m_launch_op{launch_op} {}
    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
                                  PatternRewriter& rewriter) const final {
        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
                op, constant_scalar_adaptor.value());
        return success();
    }
private:
    gpu::LaunchOp& m_launch_op;
};
struct AssignOpLowering : public ConversionPattern {
    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
              m_launch_op{launch_op} {}
@@ -230,12 +240,12 @@ struct AssignOpLowering : public ConversionPattern {
        auto loc = op->getLoc();
        AssignOpAdaptor assign_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        auto index = get_tid(rewriter, loc);
        auto loaded_lhs =
                get_operand(rewriter, loc, assign_adaptor.lhs(), index);
                get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
        rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);
        rewriter.eraseOp(op);
@@ -243,7 +253,7 @@ struct AssignOpLowering : public ConversionPattern {
    }
private:
    gpu::LaunchOp* m_launch_op;
    gpu::LaunchOp& m_launch_op;
};
class MgbToGpuLoweringPass
@@ -271,7 +281,8 @@ public:
                        cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                ReturnOpLowering,
                AssignOpLowering>(&getContext(), &launch_op);
                ConstantScalarOpLowering, AssignOpLowering>(
                &getContext(), launch_op);
#undef cb
        if (failed(applyPartialConversion(func_op, target, patterns))) {
@@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "./interfaces.td"
include "./predicates.td"
def Mgb_Dialect : Dialect {
    let name = "mgb";
@@ -90,7 +91,7 @@ def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;
class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
        ElemwiseOp<mnemonic, traits> {
    let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
    let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs);
    let results = (outs F32MemRef);
    let builders = [OpBuilder<
@@ -141,7 +142,7 @@ def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;
class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
        ElemwiseOp<mnemonic, traits> {
    let arguments = (ins F32MemRef:$x, F32MemRef:$y, F32MemRef:$z);
    let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z);
    let results = (outs F32MemRef);
    let builders = [OpBuilder<
@@ -178,6 +179,25 @@ def ReturnOp : GenericOp<"return",
}
def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> {
    let summary = "scalar constant";
    let arguments = (ins AnyAttr:$value);
    let results = (outs F32:$result);
    let builders = [OpBuilder<
        "Builder* builder, OperationState& result, float value", [{
            result.addAttribute("value", builder->getF32FloatAttr(value));
            result.addTypes(builder->getF32Type());
        }]
    >];
    let extraClassDeclaration = [{
        Attribute getValue() { return getAttr("value"); }
        FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); }
    }];
}
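ODS generates a C++ op class from this definition; a hedged sketch of constructing and querying it (builder and loc assumed available, signature per the OpBuilder declared above):

    // create an mgb.sconst holding 0.3f, then read the attribute back
    auto cst = builder.create<jit::ConstantScalarOp>(loc, 0.3f);
    mlir::FloatAttr attr = cst.getFloatAttr();  // from extraClassDeclaration
    double v = attr.getValueAsDouble();         // 0.3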
def AssignOp : GenericOp<"assign", []> {
    let summary = "assign op";
    let description = [{
@@ -0,0 +1,24 @@
/**
 * \file src/jit/impl/mlir/ir/predicates.td
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
 * implied.
 */
#ifndef MGB_MLIR_PREDICATES
#define MGB_MLIR_PREDICATES
#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif
def ElemwiseFloatAny : TypeConstraint<
    CPred<"is_elemwise_float($_self)">, "elemwise-float">;
#endif
@@ -0,0 +1,39 @@
/**
 * \file src/jit/impl/mlir/ir/types.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
 * implied.
 */
#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include <mlir/IR/StandardTypes.h>
namespace mgb {
namespace jit {
inline bool is_elemwise_float(const mlir::Type& dt) {
    if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
        if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
            return true;
        }
    }
    if (dt.isa<mlir::FloatType>()) {
        return true;
    }
    return false;
}
}  // namespace jit
}  // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
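Usage sketch for is_elemwise_float (assuming the 2020-era MLIR API included above): a plain f32 and a memref of f32 both satisfy the ElemwiseFloatAny constraint, while an i32 memref is rejected.

    mlir::MLIRContext ctx;
    auto f32 = mlir::FloatType::getF32(&ctx);
    bool ok_scalar = is_elemwise_float(f32);                               // true
    bool ok_memref = is_elemwise_float(mlir::MemRefType::get({-1}, f32));  // true
    bool rejected = is_elemwise_float(
            mlir::MemRefType::get({-1}, mlir::IntegerType::get(32, &ctx)));  // false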
@@ -49,6 +49,9 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
    megdnn::TensorShape dst;
    megdnn::DType dst_type;
    for (auto operand : operands) {
        if (operand.getType().isa<mlir::FloatType>()) {
            continue;
        }
        auto type = operand.getType().dyn_cast_or_null<mlir::MemRefType>();
        mgb_assert(type, "currently only support MemRefType");
@@ -137,6 +137,27 @@ private:
        return;
    }
    if (opr->same_type<opr::ImmutableTensor>()) {
        auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
        if (imm.valid()) {
            auto dtype = imm->dtype();
            float scalar_value;
            if (dtype == dtype::Float32()) {
                scalar_value = imm->get<float>();
            } else {
                mgb_throw(InternalError,
                          "mlir backend currently only supports f32 "
                          "dtype, but got %s",
                          dtype.name());
            }
            auto&& out = m_builder.create<jit::ConstantScalarOp>(
                    m_builder.getUnknownLoc(), m_builder.getF32Type(),
                    m_builder.getF32FloatAttr(scalar_value));
            mgb_assert(mlir::succeeded(
                    declare(opr->output(0)->name(), out)));
        }
    }
    if (opr->same_type<opr::Elemwise>()) {
        auto&& out = gen_op(opr->cast_final<opr::Elemwise>());
        mgb_assert(
@@ -137,7 +137,7 @@ void run_mlir(CompNode cn) {
            b = opr::Host2DeviceCopy::make(*graph, host_x1),
            c = opr::Host2DeviceCopy::make(*graph, host_x2);
    auto y = a + b * c;
    auto y = a + b * c + 0.3f;
    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
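For reference, the added 0.3f literal enters the graph as an opr::ImmutableTensor, so this test now exercises the new ConstantScalarOp path end to end; conceptually (illustrative notation, not tool output) the fused subgraph becomes:

    // y = ADD(ADD(a, MUL(b, c)), sconst<0.3f>)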