Browse Source

feat(mlir/ir): add more op definitions

GitOrigin-RevId: 1e1285ef41
release-1.2
Megvii Engine Team 4 years ago
parent
commit
cb59c27835
7 changed files with 38 additions and 37 deletions
  1. +14
    -9
      src/jit/impl/mlir/ir/each_mode.cpp
  2. +11
    -20
      src/jit/impl/mlir/ir/types.cpp
  3. +1
    -4
      src/jit/impl/mlir/ir/types.h
  4. +1
    -1
      src/jit/impl/mlir/ir/utils.cpp
  5. +6
    -3
      src/jit/impl/mlir/mlir_gen.cpp
  6. +3
    -0
      src/jit/include/megbrain/jit/mlir/ir/dialect.h
  7. +2
    -0
      src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td

+ 14
- 9
src/jit/impl/mlir/ir/each_mode.cpp View File

@@ -22,6 +22,7 @@
#include "megbrain/exception.h" #include "megbrain/exception.h"
#include "megbrain/jit/mlir/ir/dialect.h" #include "megbrain/jit/mlir/ir/dialect.h"


#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>


namespace mgb { namespace mgb {
@@ -442,31 +443,35 @@ mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder, mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
mlir::Location loc, mlir::Value input) { mlir::Location loc, mlir::Value input) {
auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op); auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op);
megdnn::DType idtype = typecvt.idtype();
megdnn::DType odtype = typecvt.odtype();
mlir::Type idtype = typecvt.idtype();
mlir::Type odtype =
megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext());


mlir::Type itype = input.getType(); mlir::Type itype = input.getType();
mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext());
mlir::Type otype = signless(odtype);
mgb_assert(signless(idtype) == itype);


if (mlir::FPExtOp::areCastCompatible(itype, otype)) { if (mlir::FPExtOp::areCastCompatible(itype, otype)) {
return builder.create<mlir::FPExtOp>(loc, otype, input); return builder.create<mlir::FPExtOp>(loc, otype, input);
} else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) {
return builder.create<mlir::FPTruncOp>(loc, otype, input); return builder.create<mlir::FPTruncOp>(loc, otype, input);
} else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and
is_signed_int_dtype(odtype)) {
odtype.isSignedInteger()) {
return builder.create<mlir::FPToSIOp>(loc, otype, input); return builder.create<mlir::FPToSIOp>(loc, otype, input);
} else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and
is_unsigned_int_dtype(odtype)) {
odtype.isUnsignedInteger()) {
return builder.create<mlir::FPToUIOp>(loc, otype, input); return builder.create<mlir::FPToUIOp>(loc, otype, input);
} else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and
is_signed_int_dtype(idtype)) {
idtype.isSignedInteger()) {
return builder.create<mlir::SIToFPOp>(loc, otype, input); return builder.create<mlir::SIToFPOp>(loc, otype, input);
} else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and
is_unsigned_int_dtype(idtype)) {
idtype.isUnsignedInteger()) {
return builder.create<mlir::UIToFPOp>(loc, otype, input); return builder.create<mlir::UIToFPOp>(loc, otype, input);
} else { } else {
mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(),
odtype.name());
std::string tmp;
llvm::raw_string_ostream os(tmp);
os << "cannot convert from " << idtype << " to " << odtype;
mgb_throw_raw(InternalError{tmp});
} }


return nullptr; return nullptr;


+ 11
- 20
src/jit/impl/mlir/ir/types.cpp View File

@@ -28,13 +28,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type,
case megdnn::DTypeEnum::Float32: case megdnn::DTypeEnum::Float32:
return mlir::FloatType::getF32(ctx); return mlir::FloatType::getF32(ctx);
case megdnn::DTypeEnum::Uint8: case megdnn::DTypeEnum::Uint8:
return mlir::IntegerType::get(8, ctx);
return mlir::IntegerType::get(8, mlir::IntegerType::Unsigned, ctx);
case megdnn::DTypeEnum::Int8: case megdnn::DTypeEnum::Int8:
return mlir::IntegerType::get(8, ctx);
return mlir::IntegerType::get(8, mlir::IntegerType::Signed, ctx);
case megdnn::DTypeEnum::Int16: case megdnn::DTypeEnum::Int16:
return mlir::IntegerType::get(16, ctx);
return mlir::IntegerType::get(16, mlir::IntegerType::Signed, ctx);
case megdnn::DTypeEnum::Int32: case megdnn::DTypeEnum::Int32:
return mlir::IntegerType::get(32, ctx);
return mlir::IntegerType::get(32, mlir::IntegerType::Signed, ctx);
case megdnn::DTypeEnum::IntB1: case megdnn::DTypeEnum::IntB1:
return mlir::IntegerType::get(1, ctx); return mlir::IntegerType::get(1, ctx);
case megdnn::DTypeEnum::IntB2: case megdnn::DTypeEnum::IntB2:
@@ -57,6 +57,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type,
} }
} }


mlir::Type signless(mlir::Type type) {
if (auto intty = type.dyn_cast<mlir::IntegerType>()) {
return mlir::IntegerType::get(intty.getWidth(), type.getContext());
}
return type;
}

megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) {
mlir::Type element_type = type; mlir::Type element_type = type;
if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) {
@@ -91,22 +98,6 @@ megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) {
return megdnn::DType::from_enum(enumv); return megdnn::DType::from_enum(enumv);
} }


bool is_signed_int_dtype(megdnn::DType type) {
auto enumv = type.enumv();
return enumv == megdnn::DTypeEnum::Int8 or
enumv == megdnn::DTypeEnum::Int16 or
enumv == megdnn::DTypeEnum::Int32 or
enumv == megdnn::DTypeEnum::IntB1 or
enumv == megdnn::DTypeEnum::IntB2 or
enumv == megdnn::DTypeEnum::IntB4;
}

bool is_unsigned_int_dtype(megdnn::DType type) {
auto enumv = type.enumv();
return enumv == megdnn::DTypeEnum::Uint8 or
enumv == megdnn::DTypeEnum::UintB4;
}

} // namespace jit } // namespace jit
} // namespace mgb } // namespace mgb




+ 1
- 4
src/jit/impl/mlir/ir/types.h View File

@@ -35,13 +35,10 @@ namespace jit {


mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type,
mlir::MLIRContext* ctx); mlir::MLIRContext* ctx);
mlir::Type signless(mlir::Type type);


megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type);


bool is_signed_int_dtype(megdnn::DType type);

bool is_unsigned_int_dtype(megdnn::DType type);

} // namespace jit } // namespace jit
} // namespace mgb } // namespace mgb




+ 1
- 1
src/jit/impl/mlir/ir/utils.cpp View File

@@ -87,7 +87,7 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
shape.push_back(layout[i]); shape.push_back(layout[i]);
} }
mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext()); mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext());
return mlir::MemRefType::get(shape, type);
return mlir::MemRefType::get(shape, signless(type));
} }


#endif // MGB_JIT && MGB_JIT_MLIR #endif // MGB_JIT && MGB_JIT_MLIR


+ 6
- 3
src/jit/impl/mlir/mlir_gen.cpp View File

@@ -197,12 +197,15 @@ private:
.getType() .getType()
.dyn_cast_or_null<mlir::MemRefType>(); .dyn_cast_or_null<mlir::MemRefType>();
mgb_assert(itype, "currently only support MemRefType"); mgb_assert(itype, "currently only support MemRefType");
auto output_type = megdnn_dtype_to_mlir_type(opr.param(),
m_builder.getContext());
auto res_type = mlir::MemRefType::get( auto res_type = mlir::MemRefType::get(
itype.getShape(),
megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
itype.getShape(), signless(output_type));
auto inp_type = megdnn_dtype_to_mlir_type(opr.input(0)->dtype(),
m_builder.getContext());
return m_builder.create<dialect::TypeCvt>( return m_builder.create<dialect::TypeCvt>(
m_builder.getUnknownLoc(), res_type, get(opr.input(0)), m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
opr.input(0)->dtype(), opr.param());
mlir::TypeAttr::get(inp_type), opr.param());
} }


mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) { mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) {


+ 3
- 0
src/jit/include/megbrain/jit/mlir/ir/dialect.h View File

@@ -15,7 +15,10 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR #if MGB_JIT && MGB_JIT_MLIR


#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/comp_node.h"


#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h> #include <mlir/IR/Function.h>


+ 2
- 0
src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td View File

@@ -15,6 +15,8 @@


include "ops.td" include "ops.td"


include "mlir/Interfaces/SideEffectInterfaces.td"

class GenericOp<string mnemonic, list<OpTrait> traits = []> : class GenericOp<string mnemonic, list<OpTrait> traits = []> :
Op<Mgb_Dialect, mnemonic, traits>; Op<Mgb_Dialect, mnemonic, traits>;




Loading…
Cancel
Save