GitOrigin-RevId: bd1b80c84f
release-1.1
@@ -753,12 +753,14 @@ install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
 if(MGE_WITH_JIT_MLIR)
     # generate param_defs.td
     set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles)
+    set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir)
     set(OPR_PARAM_DEFS_SRCS ${MGE_GENFILE_DIR}/opr_param_defs.py)
     set(OPR_PARAM_DEFS_SCRIPT ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_tablegen.py)
-    set(OPR_PARAM_DEFS_OUT ${MGE_GENFILE_DIR}/param_defs.td)
+    set(OPR_PARAM_DEFS_OUT ${MGE_GEN_IR_DIR}/param_defs.td)
     file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${MGE_GENFILE_DIR})
     file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
     file(APPEND ${OPR_PARAM_DEFS_SRCS} ${CONTENTS})
+    file(MAKE_DIRECTORY ${MGE_GEN_IR_DIR})
     add_custom_target(param_defs_tblgen
         COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_OUT}
         DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT}
@@ -766,7 +768,7 @@ if(MGE_WITH_JIT_MLIR)
     )
     # mlir tblgen sources
     set(MGE_IR_DIR ${PROJECT_SOURCE_DIR}/src/core/include/megbrain/ir)
-    set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_GENFILE_DIR} ${MGE_IR_DIR})
+    set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_IR_DIR} ${MGE_GEN_IR_DIR})
     list(TRANSFORM MGE_IR_INCLUDE_DIRS PREPEND "-I")
     file(GLOB_RECURSE MGE_IR_TDS ${MGE_IR_DIR}/*.td)
 endif()
@@ -1,5 +1,5 @@
 if(MGE_WITH_JIT_MLIR)
-    add_subdirectory(jit/impl/mlir/ir)
+    add_subdirectory(jit/include/megbrain/jit/mlir/ir)
 endif()
 file(GLOB_RECURSE SOURCES core/impl/*.cpp gopt/impl/*.cpp opr/impl/*.cpp opr/impl/nvof/*.cpp plugin/impl/*.cpp serialization/impl/*.cpp core/impl/*.inl gopt/impl/*.inl opr/impl/*.inl plugin/impl/*.inl serialization/impl/*.inl)
@@ -100,9 +100,10 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE)
     target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS})
 endif()
 if(MGE_WITH_JIT_MLIR)
-    target_link_libraries(megbrain PRIVATE mlir_op_def)
-    target_link_libraries(megbrain PRIVATE mlir_shape_inference)
+    target_include_directories(megbrain PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
     target_link_libraries(megbrain PRIVATE ${MLIR_LLVM_LIBS})
+    add_dependencies(megbrain mgb_dialect)
+    target_include_directories(megbrain PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/include)
 endif()
 if (MGB_WITH_FLATBUFFERS)
     set (GEN_FLATBUFFERS_SCHEMA_PY ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_flatbuffers_schema.py)
@@ -17,6 +17,7 @@
 #include "./executable_cpu.h"
 #include "./executable_cuda.h"
 #include "./mlir_gen.h"
 #include "megbrain/common.h"
 #include "megbrain/comp_node_env.h"
 #include "megbrain/jit/mlir/ir/dialect.h"
@@ -14,37 +14,44 @@
 #if MGB_JIT && MGB_JIT_MLIR
 #include "./executable_cpu.h"
+#include "./ir/types.h"
 #include "megbrain/jit/mlir/ir/utils.h"
-#include <mlir/ExecutionEngine/OptUtils.h>
 #include <mlir/ExecutionEngine/CRunnerUtils.h>
+#include <mlir/ExecutionEngine/OptUtils.h>
 using namespace mgb;
 using namespace jit;
 namespace {
+template <typename T, int N>
+StridedMemRefType<T, N>* get_strided_memref_type(
+        const megdnn::TensorND& tensor) {
+    using DescType = StridedMemRefType<T, N>;
+    DescType* desc = static_cast<DescType*>(malloc(sizeof(DescType)));
+    desc->basePtr = tensor.ptr<T>();
+    desc->data = tensor.ptr<T>();
+    desc->offset = 0;
+    for (size_t i = 0; i < tensor.layout.ndim; i++) {
+        desc->sizes[i] = tensor.layout.shape[i];
+        desc->strides[i] = tensor.layout.stride[i];
+    }
+    return desc;
+}
 template <int N>
 void* tensor2memref_dim(const megdnn::TensorND& tensor) {
     switch (tensor.layout.dtype.enumv()) {
-        case megdnn::DTypeEnum::Float32: {
-            StridedMemRefType<float, N>* desc =
-                    static_cast<StridedMemRefType<float, N>*>(
-                            malloc(sizeof(StridedMemRefType<float, N>)));
-            desc->basePtr = tensor.ptr<float>();
-            desc->data = tensor.ptr<float>();
-            desc->offset = 0;
-            for (size_t i = 0; i < tensor.layout.ndim; i++) {
-                desc->sizes[i] = tensor.layout.shape[i];
-                desc->strides[i] = tensor.layout.stride[i];
-            }
-            return desc;
-            break;
-        }
+#define cb(_dtype, _type)                                 \
+    case megdnn::DTypeEnum::_dtype:                       \
+        return get_strided_memref_type<_type, N>(tensor);
+        FOR_EACH_DNN_DTYPE(cb)
+#undef cb
         default:
-            mgb_throw(InternalError, "Unsupport dtype, got %s",
+            mgb_throw(InternalError, "Unsupported dtype: %s",
                       tensor.layout.dtype.name());
-            break;
     }
     return nullptr;
 }
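
[Editor's note: the hunk above replaces a Float32-only descriptor builder with a dtype-generic template driven by FOR_EACH_DNN_DTYPE. For readers unfamiliar with the memref ABI being filled in, here is a minimal standalone sketch of the descriptor layout; StridedMemRefDesc is an illustrative stand-in for MLIR's StridedMemRefType from <mlir/ExecutionEngine/CRunnerUtils.h>, and the 2x3 shape is made up.]

    #include <cstdint>
    #include <cstdio>

    // Same field layout the JIT-compiled function expects for each ranked
    // memref argument: two pointers, an offset, then sizes and strides.
    template <typename T, int N>
    struct StridedMemRefDesc {
        T* basePtr;          // pointer returned by the allocator
        T* data;             // aligned pointer that is actually dereferenced
        int64_t offset;      // element offset from `data`
        int64_t sizes[N];    // shape, one extent per dimension
        int64_t strides[N];  // element (not byte) strides per dimension
    };

    int main() {
        float buf[6] = {0, 1, 2, 3, 4, 5};  // dense 2x3, row-major
        StridedMemRefDesc<float, 2> desc;
        desc.basePtr = buf;
        desc.data = buf;
        desc.offset = 0;
        const int64_t shape[2] = {2, 3};
        int64_t stride = 1;
        for (int i = 1; i >= 0; --i) {  // innermost dimension gets stride 1
            desc.sizes[i] = shape[i];
            desc.strides[i] = stride;
            stride *= shape[i];
        }
        // element (1, 2) = data[offset + 1*strides[0] + 2*strides[1]] -> 5
        std::printf("%g\n", desc.data[desc.offset + desc.strides[0] +
                                      2 * desc.strides[1]]);
    }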
@@ -10,18 +10,18 @@
  * implied.
  */
-#include <vector>
 #include "megbrain_build_config.h"
-#include "megdnn/dtype.h"
 #if MGB_JIT && MGB_JIT_MLIR
 #if MGB_CUDA
 #include "./executable_cuda.h"
+#include "./ir/types.h"
 #include "megbrain/comp_node_env.h"
 #include "megbrain/jit/mlir/ir/utils.h"
 #include "megbrain/utils/persistent_cache.h"
 #include "megbrain/utils/timer.h"
+#include "megdnn/dtype.h"
 #include <mlir/Dialect/GPU/GPUDialect.h>
 #include <mlir/ExecutionEngine/CRunnerUtils.h>
@@ -83,6 +83,24 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
     MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0,
                                      env.cuda_env().stream, params.data(), 0));
 }
+template <int out_dim>
+void setup_and_launch_dim(const megdnn::DType dtype,
+                          const JITExecutor* fusion_opr, CUfunction func,
+                          int block_size) {
+    switch (dtype.enumv()) {
+#define cb(_dtype, _type)                                                \
+    case megdnn::DTypeEnum::_dtype:                                      \
+        setup_and_launch<out_dim, _type>(fusion_opr, func, block_size);  \
+        return;
+        FOR_EACH_DNN_DTYPE(cb)
+#undef cb
+        default:
+            mgb_throw(InternalError, "Unsupported dtype: %s", dtype.name());
+    }
+    return;
+}
 }  // namespace
 const std::string MLIRCUDAExecutable::sm_blob_annotation = "nvvm.cubin";
@@ -136,30 +154,19 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
                      fusion_opr->args().outputs.size());
     int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
     DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
-#define cb_outdim(_ndim, _dtype)                                  \
-    if (_ndim == out_dim) {                                       \
-        setup_and_launch<_ndim, _dtype>(fusion_opr, func->func,   \
-                                        func->block_size);        \
-        return;                                                   \
-    }
-#define cb(_dtype)                                                \
-    cb_outdim(1, float);                                          \
-    cb_outdim(2, float);                                          \
-    cb_outdim(3, float);                                          \
-    cb_outdim(4, float);                                          \
-    mgb_throw(InternalError, "unsupported out_dim=%zu",           \
-              static_cast<size_t>(out_dim));                      \
-    return;
-    switch (dtype.enumv()) {
-        case DTypeEnum::Float32:
-            cb(float);
-        default:
-            mgb_throw(InternalError, "unsupport dtype: %s", dtype.name());
-    }
+    switch (out_dim) {
+#define cb(_ndim)                                                    \
+    case _ndim:                                                      \
+        setup_and_launch_dim<_ndim>(dtype, fusion_opr, func->func,   \
+                                    func->block_size);               \
+        break;
+        cb(1);
+        cb(2);
+        cb(3);
+        cb(4);
 #undef cb
-#undef cb_outdim
+    }
 }
 #endif  // MGB_CUDA
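
[Editor's note: the exec() rewrite above splits kernel dispatch into two stages — a switch over the runtime out_dim picks a compile-time rank, then setup_and_launch_dim switches over the runtime dtype to pick a compile-time element type. A minimal sketch of that two-level pattern, with made-up stub names standing in for the real launcher:]

    #include <cstdio>
    #include <stdexcept>

    enum class DTypeEnum { Float32, Float16, Int32 };

    // Fully specialized endpoint that both switches funnel into.
    template <int ndim, typename ctype>
    void setup_and_launch_stub() {
        std::printf("launch: ndim=%d sizeof(ctype)=%zu\n", ndim, sizeof(ctype));
    }

    // Level 2: runtime dtype -> compile-time ctype.
    template <int ndim>
    void dispatch_dtype(DTypeEnum dtype) {
        switch (dtype) {
            case DTypeEnum::Float32: return setup_and_launch_stub<ndim, float>();
            case DTypeEnum::Int32:   return setup_and_launch_stub<ndim, int>();
            default: throw std::runtime_error("unsupported dtype");
        }
    }

    // Level 1: runtime ndim -> compile-time ndim.
    void dispatch(int out_dim, DTypeEnum dtype) {
        switch (out_dim) {
            case 1: return dispatch_dtype<1>(dtype);
            case 2: return dispatch_dtype<2>(dtype);
            case 3: return dispatch_dtype<3>(dtype);
            case 4: return dispatch_dtype<4>(dtype);
            default: throw std::runtime_error("unsupported out_dim");
        }
    }

    int main() { dispatch(2, DTypeEnum::Float32); }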
@@ -1,39 +0,0 @@
-set(MGB_MLIR_TABLEGEN_INC_BASE ${CMAKE_CURRENT_BINARY_DIR}/include/)
-file(MAKE_DIRECTORY ${MGB_MLIR_TABLEGEN_INC_BASE}/megbrain/jit/mlir/ir/)
-list(APPEND MGB_MLIR_TABLEGEN_INC ${MGB_MLIR_TABLEGEN_INC_BASE})
-external_tablegen_library(
-    NAME
-    mlir_shape_inference
-    TBLGEN
-    MLIR
-    SRCS
-    "interfaces.td"
-    INCLUDES
-    ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
-    OUTS
-    -gen-op-interface-decls include/megbrain/jit/mlir/ir/interfaces.h.inc
-    -gen-op-interface-defs include/megbrain/jit/mlir/ir/interfaces.cpp.inc
-)
-external_tablegen_library(
-    NAME
-    mlir_op_def
-    TBLGEN
-    MLIR
-    SRCS
-    "ops.td"
-    INCLUDES
-    ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
-    OUTS
-    -gen-op-decls include/megbrain/jit/mlir/ir/ops.h.inc
-    -gen-op-defs include/megbrain/jit/mlir/ir/ops.cpp.inc
-)
@@ -0,0 +1,7 @@
+# mgb_dialect
+set(MGB_DIALECT_TD ${PROJECT_SOURCE_DIR}/src/jit/include/megbrain/jit/mlir/ir/mgb_dialect.td)
+set(LLVM_TARGET_DEFINITIONS ${MGB_DIALECT_TD})
+tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls")
+tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs")
+add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc ${MGB_DIALECT_TD} ${MGE_IR_TDS})
+add_dependencies(mgb_dialect param_defs_tblgen)
@@ -14,91 +14,99 @@
 #if MGB_JIT && MGB_JIT_MLIR
 #include "./common.h"
 #include "megbrain/jit/mlir/ir/utils.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/StandardOps/IR/Ops.h>
 using namespace mgb;
 using namespace jit;
+/* ===================== trivial unary functions ===================== */
+#define cb(name, op)                                         \
+    mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {  \
+        return m_builder.create<mlir::op>(m_location, lhs);  \
+    }
+cb(abs, AbsFOp);
+cb(ceil, CeilFOp);
+cb(cos, CosOp);
+cb(exp2, Exp2Op);
+cb(exp, ExpOp);
+cb(floor, FloorFOp);
+cb(log10, Log10Op);
+cb(log2, Log2Op);
+cb(log, LogOp);
+cb(neg, NegFOp);
+cb(rsqrt, RsqrtOp);
+cb(sin, SinOp);
+cb(sqrt, SqrtOp);
+cb(tanh, TanhOp);
+#undef cb
+/* ===================== trivial binary functions ===================== */
 #define cb(name, op)                                                          \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) {  \
         return m_builder.create<mlir::op>(m_location, lhs, rhs);              \
     }
 cb(add, AddFOp);
-cb(sub, SubFOp);
-cb(mul, MulFOp);
-cb(div, DivFOp);
-cb(divI, SignedDivIOp);
-cb(mod, RemFOp);
 cb(bit_and, AndOp);
 cb(bit_or, OrOp);
+cb(div, DivFOp);
+cb(divI, SignedDivIOp);
 cb(modI, SignedRemIOp);
+cb(mod, RemFOp);
+cb(mul, MulFOp);
+cb(sub, SubFOp);
 #undef cb
+/* ===================== compare functions ===================== */
 #define cb(name, mode)                                                        \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) {  \
         return m_builder.create<mlir::CmpFOp>(                                \
                 m_location, mlir::CmpFPredicate::mode, lhs, rhs);             \
     }
-cb(gt, OGT);
+cb(eq, OEQ);
 cb(ge, OGE);
-cb(lt, OLT);
+cb(gt, OGT);
 cb(le, OLE);
-cb(eq, OEQ);
+cb(lt, OLT);
 #undef cb
-mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
+mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
     mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
-            m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
+            m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
     return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
 }
-mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
+mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
     mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
-            m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
+            m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
     return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
 }
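
[Editor's note: min/max are built from CmpFOp plus SelectOp because the standard dialect of this MLIR vintage has no dedicated fmin/fmax ops. A standalone sketch of the same idiom; note the NaN behavior implied by the *ordered* predicates (OGT/OLT): any NaN operand makes the compare false, so the select falls through to rhs.]

    #include <cmath>
    #include <cstdio>

    // max(a, b) = select(a > b, a, b) -- mirrors CmpFOp(OGT) + SelectOp.
    float max_via_select(float lhs, float rhs) {
        bool cmp = lhs > rhs;    // ordered greater-than: false if either is NaN
        return cmp ? lhs : rhs;  // select
    }

    // min(a, b) = select(a < b, a, b) -- mirrors CmpFOp(OLT) + SelectOp.
    float min_via_select(float lhs, float rhs) {
        bool cmp = lhs < rhs;
        return cmp ? lhs : rhs;
    }

    int main() {
        std::printf("%g %g\n", max_via_select(1.f, 2.f), min_via_select(1.f, 2.f));
        std::printf("%g\n", max_via_select(NAN, 2.f));  // NaN compares false -> 2
    }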
-mlir::Value ValueBuilderHelper::const_val(float val) {
+/* ===================== constant functions ===================== */
+mlir::Value ValueBuilderHelper::const_f32(float val) {
     return m_builder.create<mlir::ConstantOp>(m_location,
                                               m_builder.getF32FloatAttr(val));
 }
-mlir::Value ValueBuilderHelper::constI(int32_t val) {
+mlir::Value ValueBuilderHelper::const_i32(int32_t val) {
     return m_builder.create<mlir::ConstantOp>(m_location,
                                               m_builder.getIndexAttr(val));
 }
-#define cb(name, op)                                         \
-    mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {  \
-        return m_builder.create<mlir::op>(m_location, lhs);  \
-    }
-cb(neg, NegFOp);
-cb(ceil, CeilFOp);
-cb(cos, CosOp);
-cb(exp, ExpOp);
-cb(exp2, Exp2Op);
-cb(log10, Log10Op);
-cb(log2, Log2Op);
-cb(log, LogOp);
-cb(rsqrt, RsqrtOp);
-cb(sin, SinOp);
-cb(sqrt, SqrtOp);
-cb(tanh, TanhOp);
-#undef cb
-mlir::Value ValueBuilderHelper::abs(mlir::Value lhs) {
-    auto zero = const_val(0.f);
-    return select(ge(lhs, zero), lhs, sub(zero, lhs));
-}
-mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) {
-    //! FIXME use standard floor when upgrade llvm
-    return neg(ceil(neg(lhs)));
-}
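
[Editor's note: the abs()/floor() bodies removed here date from before the LLVM upgrade gave the helper direct AbsFOp/FloorFOp lowerings (now registered at the top of this hunk); they emulated both via identities — abs(x) = x >= 0 ? x : 0 - x, and floor(x) = -ceil(-x), per the FIXME. A quick standalone check of those identities:]

    #include <cmath>
    #include <cstdio>

    // abs via compare+select, as the removed helper did with ge()/select().
    float abs_emulated(float x) { return x >= 0.f ? x : 0.f - x; }

    // floor via the ceil identity used by the removed helper.
    float floor_emulated(float x) { return -std::ceil(-x); }

    int main() {
        for (float x : {-1.5f, -1.f, 0.f, 2.3f}) {
            std::printf("x=%5.2f abs=%g floor=%g (std: %g %g)\n", x,
                        abs_emulated(x), floor_emulated(x),
                        std::fabs(x), std::floor(x));
        }
    }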
+/* ===================== select function ===================== */
 mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                        mlir::Value false_val) {
@@ -106,6 +114,8 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                        false_val);
 }
+/* ===================== helper functions ===================== */
 mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
                                    const mlir::Value& val,
                                    const megdnn::TensorLayout& layout) {
@@ -125,10 +135,10 @@ mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
 }
 mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
-                                    const mlir::Location& loc,
-                                    const mlir::Value& val,
-                                    const mlir::ValueRange& index,
-                                    const megdnn::TensorLayout& dst) {
+        const mlir::Location& loc,
+        const mlir::Value& val,
+        const mlir::ValueRange& index,
+        const megdnn::TensorLayout& dst) {
     if (val.getType().isa<mlir::MemRefType>()) {
         auto type = val.getType().cast<mlir::MemRefType>();
         megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
@@ -14,7 +14,9 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 #include "megbrain/tensor.h"
 #include <mlir/Dialect/StandardOps/IR/Ops.h>
 #include <mlir/IR/OperationSupport.h>
 #include <mlir/IR/Value.h>
@@ -30,50 +32,59 @@ public:
     ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location)
             : m_builder{b}, m_location{location} {};
-#define cb(name)                                    \
-    mlir::Value name(mlir::ValueRange operands) {   \
-        return name(operands[0], operands[1]);      \
-    }                                               \
-    mlir::Value name(mlir::Value lhs, mlir::Value rhs)
-    cb(add);
-    cb(sub);
-    cb(mul);
-    cb(div);
-    cb(divI);
-    cb(max);
-    cb(min);
-    cb(mod);
-    cb(modI);
-    cb(gt);
-    cb(ge);
-    cb(lt);
-    cb(le);
-    cb(eq);
-    cb(bit_and);
-    cb(bit_or);
-#undef cb
-    mlir::Value const_val(float val);
-    mlir::Value constI(int32_t val);
 #define cb(name)                                                              \
     mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
     mlir::Value name(mlir::Value lhs)
-    cb(neg);
+    // unary functions
     cb(abs);
     cb(ceil);
-    cb(floor);
     cb(cos);
     cb(exp);
     cb(exp2);
+    cb(floor);
+    cb(log);
     cb(log10);
     cb(log2);
-    cb(log);
+    cb(neg);
     cb(rsqrt);
     cb(sin);
     cb(sqrt);
     cb(tanh);
 #undef cb
+#define cb(name)                                    \
+    mlir::Value name(mlir::ValueRange operands) {   \
+        return name(operands[0], operands[1]);      \
+    }                                               \
+    mlir::Value name(mlir::Value lhs, mlir::Value rhs)
+    // binary functions
+    cb(add);
+    cb(bit_and);
+    cb(bit_or);
+    cb(div);
+    cb(divI);
+    cb(eq);
+    cb(ge);
+    cb(gt);
+    cb(le);
+    cb(lt);
+    cb(max);
+    cb(min);
+    cb(mod);
+    cb(modI);
+    cb(mul);
+    cb(sub);
+#undef cb
+    // constant functions
+    mlir::Value const_f32(float val);
+    mlir::Value const_i32(int32_t val);
+    // select function
     mlir::Value select(mlir::Value cond, mlir::Value true_val,
                        mlir::Value false_val);
@@ -14,6 +14,7 @@
 #if MGB_JIT && MGB_JIT_MLIR
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "./types.h"
 #include <mlir/IR/Builders.h>
@@ -28,14 +29,12 @@ MgbDialect::MgbDialect(mlir::MLIRContext* ctx)
         : mlir::Dialect("mgb", ctx, mlir::TypeID::get<MgbDialect>()) {
     addOperations<
 #define GET_OP_LIST
-#include "megbrain/jit/mlir/ir/ops.cpp.inc"
+#include "megbrain/jit/mlir/ir/mgb_dialect.cpp.inc"
             >();
 }
 #define GET_OP_CLASSES
-#include "megbrain/jit/mlir/ir/ops.cpp.inc"
-#include "megbrain/jit/mlir/ir/interfaces.cpp.inc"
+#include "megbrain/jit/mlir/ir/mgb_dialect.cpp.inc"
 #endif  // MGB_JIT && MGB_JIT_MLIR
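
[Editor's note: the hunk above swaps the hand-split ops.cpp.inc/interfaces.cpp.inc for a single tablegen-generated mgb_dialect.cpp.inc, included twice under different guard macros — GET_OP_LIST yields a comma-separated list of op types for addOperations<...>(), GET_OP_CLASSES yields the class definitions. A sketch of the registration pattern behind addOperations, with stand-in types rather than MLIR's real API:]

    #include <cstdio>
    #include <string>
    #include <typeinfo>
    #include <vector>

    struct AddOp {};  // stand-ins for tablegen-generated op classes
    struct MulOp {};

    struct Dialect {
        std::vector<std::string> ops;
        // Records one entry per op type in the parameter pack (C++17 fold).
        template <typename... Ops>
        void addOperations() { (ops.push_back(typeid(Ops).name()), ...); }
    };

    int main() {
        Dialect d;
        // In dialect.cpp the pack comes from mgb_dialect.cpp.inc under
        // #define GET_OP_LIST; here it is spelled out by hand.
        d.addOperations<AddOp, MulOp>();
        std::printf("%zu ops registered\n", d.ops.size());
    }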
@@ -0,0 +1,480 @@
+/**
+ * \file src/jit/impl/mlir/ir/each_mode.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+#include "./common.h"
+#include "./each_mode.h"
+#include "./numerical.h"
+#include "./types.h"
+#include "megbrain/common.h"
+#include "megbrain/exception.h"
+#include "megbrain/jit/mlir/ir/dialect.h"
+#include <mlir/Dialect/StandardOps/IR/Ops.h>
+namespace mgb {
+namespace jit {
+using Mode = megdnn::param::Elemwise::Mode;
+template <Mode mode>
+mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc,
+                       ValueRange operands);
+/* ===================== trivial implementations ===================== */
+#define cb(mode, fun)                                              \
+    template <>                                                    \
+    mlir::Value lower_mode<Mode::mode>(mlir::OpBuilder & builder,  \
+                                       mlir::Location loc,         \
+                                       ValueRange operands) {      \
+        ValueBuilderHelper helper(builder, loc);                   \
+        return helper.fun(operands);                               \
+    }
+//! unary
+cb(ABS, abs);
+cb(CEIL, ceil);
+cb(COS, cos);
+cb(EXP, exp);
+cb(FLOOR, floor);
+cb(LOG, log);
+cb(NEGATE, neg);
+cb(SIN, sin);
+cb(TANH, tanh);
+//! binary
+cb(ADD, add);
+cb(MAX, max);
+cb(MIN, min);
+cb(MOD, mod);
+cb(MUL, mul);
+cb(SUB, sub);
+cb(TRUE_DIV, div);
+#undef cb
+/* ===================== unary op ===================== */
+//! ACOS: pi / 2 - arctan2(x, sqrt(1 - x * x))
+template <>
+mlir::Value lower_mode<Mode::ACOS>(mlir::OpBuilder& builder, mlir::Location loc,
+                                   ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto x = operands[0];
+    auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
+    auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
+    auto pi_over_2 = helper.const_f32(1.57079637f);
+    return helper.sub(pi_over_2, asin);
+}
+//! ASIN: arctan2(x, sqrt(1 - x * x))
+template <>
+mlir::Value lower_mode<Mode::ASIN>(mlir::OpBuilder& builder, mlir::Location loc,
+                                   ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto x = operands[0];
+    auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
+    return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
+}
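
[Editor's note: both lowerings lean on the trig identities asin(x) = atan2(x, sqrt(1 - x^2)) and acos(x) = pi/2 - asin(x), so only an atan2 approximation is needed. A quick libm spot-check of the identities:]

    #include <cmath>
    #include <cstdio>

    int main() {
        const double pi_over_2 = 2.0 * std::atan(1.0);
        for (double x : {-0.9, -0.3, 0.0, 0.5, 0.99}) {
            double asin_id = std::atan2(x, std::sqrt(1.0 - x * x));
            double acos_id = pi_over_2 - asin_id;
            std::printf("x=% .2f  asin err=%.1e  acos err=%.1e\n", x,
                        asin_id - std::asin(x), acos_id - std::acos(x));
        }
    }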
+//! ERFCINV: inverse of complementary gauss error function
+//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
+template <>
+mlir::Value lower_mode<Mode::ERFCINV>(mlir::OpBuilder& builder,
+                                      mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto minus_sqrt2 = helper.const_f32(-1.4142135623f);
+    auto x = helper.mul(helper.const_f32(0.5f), operands[0]);
+    return helper.div(ndtri_approx(helper, x), minus_sqrt2);
+}
+//! ERFC: complementary error function
+template <>
+mlir::Value lower_mode<Mode::ERFC>(mlir::OpBuilder& builder, mlir::Location loc,
+                                   ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0]));
+}
+//! ERFINV: inverse of gauss error function
+//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
+template <>
+mlir::Value lower_mode<Mode::ERFINV>(mlir::OpBuilder& builder,
+                                     mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto sqrt2 = helper.const_f32(1.4142135623f);
+    auto x = helper.mul(helper.const_f32(0.5f),
+                        helper.add(operands[0], helper.const_f32(1.f)));
+    return helper.div(ndtri_approx(helper, x), sqrt2);
+}
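
[Editor's note: both inverse error functions are routed through ndtri, the inverse standard-normal CDF, via erfinv(x) = ndtri((x + 1) / 2) / sqrt(2) and erfcinv(z) = -ndtri(z / 2) / sqrt(2) — exactly the arithmetic in the two bodies above. As an independent reference (this is Newton iteration on libm's erf, not the cephes polynomial that ndtri_approx presumably implements):]

    #include <cmath>
    #include <cstdio>

    // Solve erf(y) = x by Newton's method; d/dy erf(y) = (2/sqrt(pi)) e^{-y^2}.
    double erfinv_newton(double x) {
        const double two_over_sqrt_pi = 1.1283791670955126;
        double y = 0.0;
        for (int i = 0; i < 50; ++i) {
            double err = std::erf(y) - x;
            y -= err / (two_over_sqrt_pi * std::exp(-y * y));
        }
        return y;
    }

    int main() {
        // erfcinv(z) = erfinv(1 - z), so one inverse suffices for a check.
        std::printf("erfinv(0.5) = %f\n", erfinv_newton(0.5));  // ~0.476936
    }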
+//! ERF: gauss error function
+template <>
+mlir::Value lower_mode<Mode::ERF>(mlir::OpBuilder& builder, mlir::Location loc,
+                                  ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return erf_approx(helper, operands[0]);
+}
+//! EXPM1: exp(x) - 1
+template <>
+mlir::Value lower_mode<Mode::EXPM1>(mlir::OpBuilder& builder,
+                                    mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f));
+}
+//! FAST_TANH: x * (27.f + x * x) / (27.f + 9.f * x * x);
+template <>
+mlir::Value lower_mode<Mode::FAST_TANH>(mlir::OpBuilder& builder,
+                                        mlir::Location loc,
+                                        ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto square = helper.mul(operands[0], operands[0]);
+    return helper.div(
+            helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)),
+            helper.add(helper.const_f32(27.f),
+                       helper.mul(helper.const_f32(9.f), square)));
+}
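
[Editor's note: FAST_TANH is a Pade-style rational approximation — cheap because it needs only multiplies, adds, and one divide — that tracks tanh well for small |x| but drifts as |x| grows (for large x it behaves like x/9 rather than saturating at 1). A standalone comparison against libm:]

    #include <cmath>
    #include <cstdio>

    float fast_tanh(float x) {
        float x2 = x * x;
        return x * (27.f + x2) / (27.f + 9.f * x2);
    }

    int main() {
        for (float x : {0.1f, 0.5f, 1.f, 2.f, 3.f}) {
            std::printf("x=%g fast=%f tanh=%f\n", x, fast_tanh(x), std::tanh(x));
        }
    }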
+//! H_SWISH: x * clip(x + 3, 0, 6) / 6
+template <>
+mlir::Value lower_mode<Mode::H_SWISH>(mlir::OpBuilder& builder,
+                                      mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto const_3 = helper.const_f32(3.f);
+    auto const_0 = helper.const_f32(0.f);
+    auto const_6 = helper.const_f32(6.f);
+    auto tmp = helper.add(operands[0], const_3);
+    return helper.div(helper.mul(operands[0],
+                                 helper.min(helper.max(tmp, const_0), const_6)),
+                      const_6);
+}
+//! LOG1P: log(1 + p)
+template <>
+mlir::Value lower_mode<Mode::LOG1P>(mlir::OpBuilder& builder,
+                                    mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.log(helper.add(operands[0], helper.const_f32(1.f)));
+}
+//! RELU: max(x, 0)
+template <>
+mlir::Value lower_mode<Mode::RELU>(mlir::OpBuilder& builder, mlir::Location loc,
+                                   ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.max(operands[0], helper.const_f32(0.f));
+}
+//! ROUND
+template <>
+mlir::Value lower_mode<Mode::ROUND>(mlir::OpBuilder& builder,
+                                    mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(
+            helper.gt(operands[0], helper.const_f32(0.f)),
+            helper.floor(helper.add(operands[0], helper.const_f32(0.5f))),
+            helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f))));
+}
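
[Editor's note: the ROUND lowering implements round-half-away-from-zero using only floor/ceil — floor(x + 0.5) for positive x, ceil(x - 0.5) otherwise — the same convention as C's round(). A standalone check:]

    #include <cmath>
    #include <cstdio>

    float round_away(float x) {
        return x > 0.f ? std::floor(x + 0.5f) : std::ceil(x - 0.5f);
    }

    int main() {
        for (float x : {2.5f, -2.5f, 0.4f, -0.4f}) {
            std::printf("x=% .1f -> % .1f (std::round % .1f)\n", x,
                        round_away(x), std::round(x));
        }
    }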
+//! SIGMOID: 1.f / (expf(-y) + 1.f))
+template <>
+mlir::Value lower_mode<Mode::SIGMOID>(mlir::OpBuilder& builder,
+                                      mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.div(helper.const_f32(1.f),
+                      helper.add(helper.exp(helper.neg(operands[0])),
+                                 helper.const_f32(1.f)));
+}
+/* ===================== binary op ===================== */
+//! ABS_GRAD: x > 0 ? y : -y
+template <>
+mlir::Value lower_mode<Mode::ABS_GRAD>(mlir::OpBuilder& builder,
+                                       mlir::Location loc,
+                                       ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(helper.gt(operands[0], helper.const_f32(0.f)),
+                         operands[1], helper.neg(operands[1]));
+}
+//! ATAN2
+template <>
+mlir::Value lower_mode<Mode::ATAN2>(mlir::OpBuilder& builder,
+                                    mlir::Location loc, ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return atan2_approx(helper, operands[0], operands[1]);
+}
+//! EQ: x == y ? 1 : 0
+template <>
+mlir::Value lower_mode<Mode::EQ>(mlir::OpBuilder& builder, mlir::Location loc,
+                                 ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(helper.eq(operands[0], operands[1]),
+                         helper.const_f32(1.f), helper.const_f32(0.f));
+}
+//! FAST_TANH_GRAD: ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x
+//! * x) * y
+template <>
+mlir::Value lower_mode<Mode::FAST_TANH_GRAD>(mlir::OpBuilder& builder,
+                                             mlir::Location loc,
+                                             ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto x_pow2 = helper.mul(operands[0], operands[0]);
+    auto deno = helper.add(helper.const_f32(3.f), x_pow2);
+    return helper.mul(
+            helper.div(
+                    helper.add(
+                            helper.add(helper.div(helper.mul(helper.const_f32(
+                                                                     -48.f),
+                                                             x_pow2),
+                                                  deno),
+                                       helper.const_f32(27.f)),
+                            x_pow2),
+                    helper.mul(deno, helper.const_f32(9.f))),
+            operands[1]);
+}
+//! FLOOR_DIV: floor(x/y)
+template <>
+mlir::Value lower_mode<Mode::FLOOR_DIV>(mlir::OpBuilder& builder,
+                                        mlir::Location loc,
+                                        ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.floor(helper.div(operands[0], operands[1]));
+}
+//! FUSE_ADD_H_SWISH: (x+y) * min(max(x + y + 3, 0), 6) * (1/6)
+template <>
+mlir::Value lower_mode<Mode::FUSE_ADD_H_SWISH>(mlir::OpBuilder& builder,
+                                               mlir::Location loc,
+                                               ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto sum = helper.add(operands[0], operands[1]);
+    auto const_3 = helper.const_f32(3.f);
+    auto const_0 = helper.const_f32(0.f);
+    auto const_6 = helper.const_f32(6.f);
+    auto tmp = helper.add(sum, const_3);
+    return helper.div(
+            helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)),
+            const_6);
+}
+//! FUSE_ADD_RELU: (x + y) <= ctype(0) ? ctype(0) : (x + y)
+template <>
+mlir::Value lower_mode<Mode::FUSE_ADD_RELU>(mlir::OpBuilder& builder,
+                                            mlir::Location loc,
+                                            ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    auto sum = helper.add(operands[0], operands[1]);
+    return helper.max(sum, helper.const_f32(0.f));
+}
+//! FUSE_ADD_SIGMOID: 1.f / (expf(-(x+y)) + 1.f))
+template <>
+mlir::Value lower_mode<Mode::FUSE_ADD_SIGMOID>(mlir::OpBuilder& builder,
+                                               mlir::Location loc,
+                                               ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.div(helper.const_f32(1.f),
+                      helper.add(helper.exp(helper.neg(
+                                         helper.add(operands[0], operands[1]))),
+                                 helper.const_f32(1.f)));
+}
+//! FUSE_ADD_TANH: tanh(x + y)
+template <>
+mlir::Value lower_mode<Mode::FUSE_ADD_TANH>(mlir::OpBuilder& builder,
+                                            mlir::Location loc,
+                                            ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.tanh(helper.add(operands[0], operands[1]));
+}
+//! H_SWISH_GRAD: x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
+template <>
+mlir::Value lower_mode<Mode::H_SWISH_GRAD>(mlir::OpBuilder& builder,
+                                           mlir::Location loc,
+                                           ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(
+            helper.lt(operands[0], helper.const_f32(-3.f)),
+            helper.const_f32(0.f),
+            helper.select(
+                    helper.gt(operands[0], helper.const_f32(3.f)), operands[1],
+                    helper.mul(
+                            helper.div(
+                                    helper.add(helper.mul(helper.const_f32(2.f),
+                                                          operands[0]),
+                                               helper.const_f32(3.f)),
+                                    helper.const_f32(6.f)),
+                            operands[1])));
+}
+//! LEQ: x <= y ? 1 : 0
+template <>
+mlir::Value lower_mode<Mode::LEQ>(mlir::OpBuilder& builder, mlir::Location loc,
+                                  ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(helper.le(operands[0], operands[1]),
+                         helper.const_f32(1.f), helper.const_f32(0.f));
+}
+//! LOG_SUM_EXP: log(exp(x) + exp(y))
+template <>
+mlir::Value lower_mode<Mode::LOG_SUM_EXP>(mlir::OpBuilder& builder,
+                                          mlir::Location loc,
+                                          ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.log(
+            helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
+}
+//! LT: x < y ? 1 : 0
+template <>
+mlir::Value lower_mode<Mode::LT>(mlir::OpBuilder& builder, mlir::Location loc,
+                                 ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(helper.lt(operands[0], operands[1]),
+                         helper.const_f32(1.f), helper.const_f32(0.f));
+}
+//! POW: x^y = exp(y * log(x))
+template <>
+mlir::Value lower_mode<Mode::POW>(mlir::OpBuilder& builder, mlir::Location loc,
+                                  ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
+}
+//! SIGMOID_GRAD: x * (1 - x) * y
+template <>
+mlir::Value lower_mode<Mode::SIGMOID_GRAD>(mlir::OpBuilder& builder,
+                                           mlir::Location loc,
+                                           ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.mul(helper.mul(operands[0], helper.sub(helper.const_f32(1.f),
+                                                         operands[0])),
+                      operands[1]);
+}
+//! SWITCH_GT0: (x > 0) * y
+template <>
+mlir::Value lower_mode<Mode::SWITCH_GT0>(mlir::OpBuilder& builder,
+                                         mlir::Location loc,
+                                         ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(helper.gt(operands[0], helper.const_f32(0.f)),
+                         operands[1], helper.const_f32(0.f));
+}
+//! TANH_GRAD: (1 - x * x) * y
+template <>
+mlir::Value lower_mode<Mode::TANH_GRAD>(mlir::OpBuilder& builder,
+                                        mlir::Location loc,
+                                        ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.mul(helper.sub(helper.const_f32(1.0f),
+                                 helper.mul(operands[0], operands[0])),
+                      operands[1]);
+}
+/* ===================== ternary op ===================== */
+//! COND_LEQ_MOV: x <= y ? z : ctype(0)
+template <>
+mlir::Value lower_mode<Mode::COND_LEQ_MOV>(mlir::OpBuilder& builder,
+                                           mlir::Location loc,
+                                           ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.select(helper.le(operands[0], operands[1]), operands[2],
+                         helper.const_f32(0.f));
+}
+//! FUSE_MUL_ADD3: x * y + z
+template <>
+mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(mlir::OpBuilder& builder,
+                                            mlir::Location loc,
+                                            ValueRange operands) {
+    ValueBuilderHelper helper(builder, loc);
+    return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
+}
+/* ===================== elemwise ===================== */
+mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
+                                  mlir::Location loc, ValueRange operands) {
+    auto mode = llvm::dyn_cast<dialect::Elemwise>(op).mode();
+    switch (mode) {
+#define cb(_, _mode)                                            \
+    case Mode::_mode:                                           \
+        return lower_mode<Mode::_mode>(builder, loc, operands);
+        MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb);
+        MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb);
+        MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb);
+        default:
+            return nullptr;
+    }
+#undef cb
+}
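
[Editor's note: lower_elemwise_to_std uses an X-macro to turn the runtime Mode enum into a call to the matching compile-time lower_mode specialization — one cb() expansion per (op, mode) pair builds the whole switch. A minimal standalone sketch of the same pattern, with a made-up two-entry mode list:]

    #include <cmath>
    #include <cstdio>
    #include <stdexcept>

    enum class Mode { RELU, SIGMOID };

    // One explicit specialization per mode, mirroring lower_mode<Mode::...>.
    template <Mode mode> float lower_mode(float x);
    template <> float lower_mode<Mode::RELU>(float x) { return x > 0.f ? x : 0.f; }
    template <> float lower_mode<Mode::SIGMOID>(float x) {
        return 1.f / (1.f + std::exp(-x));
    }

    // The FOREACH list: expands cb() once per mode.
    #define FOREACH_MODE(cb) cb(RELU) cb(SIGMOID)

    float lower(Mode mode, float x) {
        switch (mode) {
    #define cb(_mode)     \
        case Mode::_mode: \
            return lower_mode<Mode::_mode>(x);
            FOREACH_MODE(cb)
    #undef cb
        }
        throw std::runtime_error("unsupported mode");
    }

    int main() {
        std::printf("%f %f\n", lower(Mode::RELU, -2.f), lower(Mode::SIGMOID, 0.f));
    }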
+/* ===================== typecvt ===================== */
+mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
+                                 mlir::Location loc, mlir::Value input) {
+    auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op);
+    megdnn::DType idtype = typecvt.idtype();
+    megdnn::DType odtype = typecvt.odtype();
+    mlir::Type itype = input.getType();
+    mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext());
+    if (mlir::FPExtOp::areCastCompatible(itype, otype)) {
+        return builder.create<mlir::FPExtOp>(loc, otype, input);
+    } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) {
+        return builder.create<mlir::FPTruncOp>(loc, otype, input);
+    } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and
+               is_signed_int_dtype(odtype)) {
+        return builder.create<mlir::FPToSIOp>(loc, otype, input);
+    } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and
+               is_unsigned_int_dtype(odtype)) {
+        return builder.create<mlir::FPToUIOp>(loc, otype, input);
+    } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and
+               is_signed_int_dtype(idtype)) {
+        return builder.create<mlir::SIToFPOp>(loc, otype, input);
+    } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and
+               is_unsigned_int_dtype(idtype)) {
+        return builder.create<mlir::UIToFPOp>(loc, otype, input);
+    } else {
+        mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(),
+                  odtype.name());
+    }
+    return nullptr;
+}
+}  // namespace jit
+}  // namespace mgb
+#endif  // MGB_JIT && MGB_JIT_MLIR
+// vim: syntax=cpp.doxygen
@@ -15,65 +15,60 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
-#include "megbrain/jit/mlir/ir/dialect.h"
-#include "megdnn/opr_param_defs.h"
-#include "./common.h"
-#include "./numerical.h"
-#include <mlir/Dialect/StandardOps/IR/Ops.h>
 #include <mlir/IR/Builders.h>
-#include <mlir/IR/Value.h>
 // clang-format off
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \
-    cb(ReluOp, RELU) \
     cb(AbsOp, ABS) \
-    cb(NegOp, NEGATE) \
     cb(AcosOp, ACOS) \
     cb(AsinOp, ASIN) \
     cb(CeilOp, CEIL) \
     cb(CosOp, COS) \
+    cb(ErfCInvOp, ERFCINV) \
+    cb(ErfCOp, ERFC) \
+    cb(ErfInvOp, ERFINV) \
+    cb(ErfOp, ERF) \
+    cb(ExpM1Op, EXPM1) \
     cb(ExpOp, EXP) \
+    cb(FastTanhOp, FAST_TANH) \
    cb(FloorOp, FLOOR) \
-    cb(LogOp, LOG) \
+    cb(HswishOp, H_SWISH) \
     cb(Log1POp, LOG1P) \
+    cb(LogOp, LOG) \
+    cb(NegOp, NEGATE) \
+    cb(ReluOp, RELU) \
+    cb(RoundOp, ROUND) \
     cb(SigmoidOp, SIGMOID) \
     cb(SinOp, SIN) \
-    cb(TanhOp, TANH) \
-    cb(FastTanhOp, FAST_TANH) \
-    cb(HswishOp, H_SWISH) \
-    cb(ExpM1Op, EXPM1) \
-    cb(RoundOp, ROUND) \
-    cb(ErfOp, ERF) \
-    cb(ErfInvOp, ERFINV) \
-    cb(ErfCOp, ERFC) \
-    cb(ErfCInvOp, ERFCINV)
+    cb(TanhOp, TANH)
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \
     cb(AbsGradOp, ABS_GRAD) \
     cb(AddOp, ADD) \
+    cb(Atan2Op, ATAN2) \
+    cb(EqOp, EQ) \
+    cb(FastTanhGradOp, FAST_TANH_GRAD) \
     cb(FloorDivOp, FLOOR_DIV) \
+    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
+    cb(FuseAddReluOp, FUSE_ADD_RELU) \
+    cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
+    cb(FuseAddTanhOp, FUSE_ADD_TANH) \
+    cb(HswishGradOp, H_SWISH_GRAD) \
+    cb(LeqOp, LEQ) \
+    cb(LogSumExpOp, LOG_SUM_EXP) \
+    cb(LtOp, LT) \
     cb(MaxOp, MAX) \
     cb(MinOp, MIN) \
     cb(ModOp, MOD) \
-    cb(SubOp, SUB) \
     cb(MulOp, MUL) \
-    cb(TrueDivOp, TRUE_DIV) \
     cb(PowOp, POW) \
     cb(SigmoidGradOp, SIGMOID_GRAD) \
+    cb(SubOp, SUB) \
     cb(SwishGt0Op, SWITCH_GT0) \
     cb(TanhGradOp, TANH_GRAD) \
-    cb(LtOp, LT) \
-    cb(LeqOp, LEQ) \
-    cb(EqOp, EQ) \
-    cb(FuseAddReluOp, FUSE_ADD_RELU) \
-    cb(LogSumExpOp, LOG_SUM_EXP) \
-    cb(FuseAddTanhOp, FUSE_ADD_TANH) \
-    cb(FastTanhGradOp, FAST_TANH_GRAD) \
-    cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
-    cb(HswishGradOp, H_SWISH_GRAD) \
-    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
-    cb(Atan2Op, ATAN2)
+    cb(TrueDivOp, TRUE_DIV)
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
     cb(CondLeqMovOp, COND_LEQ_MOV) \
@@ -83,432 +78,19 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace jit { | namespace jit { | ||||
template <typename mgb_op> | |||||
struct StandardOp; | |||||
#define cb(mgb_op, fun) \ | |||||
template <> \ | |||||
struct StandardOp<jit::mgb_op> { \ | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \ | |||||
ValueRange operands) { \ | |||||
ValueBuilderHelper helper(builder, loc); \ | |||||
return helper.fun(operands); \ | |||||
} \ | |||||
} | |||||
//! unary | |||||
cb(AbsOp, abs); | |||||
cb(NegOp, neg); | |||||
cb(ExpOp, exp); | |||||
cb(CosOp, cos); | |||||
cb(CeilOp, ceil); | |||||
cb(FloorOp, floor); | |||||
cb(LogOp, log); | |||||
cb(SinOp, sin); | |||||
cb(TanhOp, tanh); | |||||
//! binary | |||||
cb(AddOp, add); | |||||
cb(MaxOp, max); | |||||
cb(MinOp, min); | |||||
cb(SubOp, sub); | |||||
cb(MulOp, mul); | |||||
cb(ModOp, mod); | |||||
cb(TrueDivOp, div); | |||||
#undef cb | |||||
/////////////////////////// unary op /////////////////////////// | |||||
//! max(x, 0) | |||||
template <> | |||||
struct StandardOp<jit::ReluOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.max(operands[0], helper.const_val(0.f)); | |||||
} | |||||
}; | |||||
//! x * (27.f + x * x) / (27.f + 9.f * x * x); | |||||
template <> | |||||
struct StandardOp<jit::FastTanhOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto square = helper.mul(operands[0], operands[0]); | |||||
return helper.div( | |||||
helper.mul(operands[0], | |||||
helper.add(helper.const_val(27.f), square)), | |||||
helper.add(helper.const_val(27.f), | |||||
helper.mul(helper.const_val(9.f), square))); | |||||
} | |||||
}; | |||||
//! x * clip(x + 3, 0, 6) / 6 | |||||
template <> | |||||
struct StandardOp<jit::HswishOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto const_3 = helper.const_val(3.f); | |||||
auto const_0 = helper.const_val(0.f); | |||||
auto const_6 = helper.const_val(6.f); | |||||
auto tmp = helper.add(operands[0], const_3); | |||||
return helper.div( | |||||
helper.mul(operands[0], | |||||
helper.min(helper.max(tmp, const_0), const_6)), | |||||
const_6); | |||||
} | |||||
}; | |||||
//! log(1 + p) | |||||
template <> | |||||
struct StandardOp<jit::Log1POp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.log(helper.add(operands[0], helper.const_val(1.f))); | |||||
} | |||||
}; | |||||
//! 1.f / (expf(-y) + 1.f)) | |||||
template <> | |||||
struct StandardOp<jit::SigmoidOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.div(helper.const_val(1.f), | |||||
helper.add(helper.exp(helper.neg(operands[0])), | |||||
helper.const_val(1.f))); | |||||
} | |||||
}; | |||||
//! exp(x) - 1 | |||||
template <> | |||||
struct StandardOp<jit::ExpM1Op> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.sub(helper.exp(operands[0]), helper.const_val(1.f)); | |||||
} | |||||
}; | |||||
template <> | |||||
struct StandardOp<jit::RoundOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.select( | |||||
helper.gt(operands[0], helper.const_val(0.f)), | |||||
helper.floor(helper.add(operands[0], helper.const_val(0.5f))), | |||||
helper.ceil(helper.sub(operands[0], helper.const_val(0.5f)))); | |||||
} | |||||
}; | |||||
//! pi / 2 - arctan2(x, sqrt(1 - x * x)) | |||||
template <> | |||||
struct StandardOp<jit::AcosOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto x = operands[0]; | |||||
auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x)); | |||||
auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); | |||||
auto pi_over_2 = helper.const_val(1.57079637f); | |||||
return helper.sub(pi_over_2, asin); | |||||
} | |||||
}; | |||||
//! arctan2(x, sqrt(1 - x * x)) | |||||
template <> | |||||
struct StandardOp<jit::AsinOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto x = operands[0]; | |||||
auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x)); | |||||
return atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); | |||||
} | |||||
}; | |||||
//! gauss error function | |||||
template <> | |||||
struct StandardOp<jit::ErfOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return erf_approx(helper, operands[0]); | |||||
} | |||||
}; | |||||
//! inverse of gauss error function | |||||
//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c | |||||
template <> | |||||
struct StandardOp<jit::ErfInvOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto sqrt2 = helper.const_val(1.4142135623f); | |||||
auto x = helper.mul(helper.const_val(0.5f), | |||||
helper.add(operands[0], helper.const_val(1.f))); | |||||
return helper.div(ndtri_approx(helper, x), sqrt2); | |||||
} | |||||
}; | |||||
//! complementary error function | |||||
template <> | |||||
struct StandardOp<jit::ErfCOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.sub(helper.const_val(1.f), erf_approx(helper, operands[0])); | |||||
} | |||||
}; | |||||
//! inverse of complementary gauss error function | |||||
//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c | |||||
template <> | |||||
struct StandardOp<jit::ErfCInvOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto minus_sqrt2 = helper.const_val(-1.4142135623f); | |||||
auto x = helper.mul(helper.const_val(0.5f), operands[0]); | |||||
return helper.div(ndtri_approx(helper, x), minus_sqrt2); | |||||
} | |||||
}; | |||||
/////////////////////////// binary op /////////////////////////// | |||||
//! binary: x > 0 ? y : -y | |||||
template <> | |||||
struct StandardOp<jit::AbsGradOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.select(helper.gt(operands[0], helper.const_val(0.f)), | |||||
operands[1], helper.neg(operands[1])); | |||||
} | |||||
}; | |||||
//! x^y = exp(y * log(x)) | |||||
template <> | |||||
struct StandardOp<jit::PowOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.exp(helper.mul(operands[1], helper.log(operands[0]))); | |||||
} | |||||
}; | |||||
//! x * (1 - x) * y | |||||
template <> | |||||
struct StandardOp<jit::SigmoidGradOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.mul( | |||||
helper.mul(operands[0], | |||||
helper.sub(helper.const_val(1.f), operands[0])), | |||||
operands[1]); | |||||
} | |||||
}; | |||||
//! (x > 0) * y | |||||
template <> | |||||
struct StandardOp<jit::SwishGt0Op> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.select(helper.gt(operands[0], helper.const_val(0.f)), | |||||
operands[1], helper.const_val(0.f)); | |||||
} | |||||
}; | |||||
//! (1 - x * x) * y | |||||
template <> | |||||
struct StandardOp<jit::TanhGradOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.mul(helper.sub(helper.const_val(1.0f), | |||||
helper.mul(operands[0], operands[0])), | |||||
operands[1]); | |||||
} | |||||
}; | |||||
#define cb(op, fun) \ | |||||
template <> \ | |||||
struct StandardOp<jit::op> { \ | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \ | |||||
ValueRange operands) { \ | |||||
ValueBuilderHelper helper(builder, loc); \ | |||||
return helper.select(helper.fun(operands[0], operands[1]), \ | |||||
helper.const_val(1.f), \ | |||||
helper.const_val(0.f)); \ | |||||
} \ | |||||
} | |||||
cb(LtOp, lt); | |||||
cb(LeqOp, le); | |||||
cb(EqOp, eq); | |||||
#undef cb | |||||
//! (x + y) <= ctype(0) ? ctype(0) : (x + y) | |||||
template <> | |||||
struct StandardOp<jit::FuseAddReluOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto sum = helper.add(operands[0], operands[1]); | |||||
return helper.max(sum, helper.const_val(0.f)); | |||||
} | |||||
}; | |||||
//! log(exp(x) + exp(y)) | |||||
template <> | |||||
struct StandardOp<jit::LogSumExpOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.log( | |||||
helper.add(helper.exp(operands[0]), helper.exp(operands[1]))); | |||||
} | |||||
}; | |||||
//! floor(x/y) | |||||
template <> | |||||
struct StandardOp<jit::FloorDivOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.floor(helper.div(operands[0], operands[1])); | |||||
} | |||||
}; | |||||
//! tanh(x + y) | |||||
template <> | |||||
struct StandardOp<jit::FuseAddTanhOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.tanh(helper.add(operands[0], operands[1])); | |||||
} | |||||
}; | |||||
//! ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x * x) * y | |||||
template <> | |||||
struct StandardOp<jit::FastTanhGradOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto x_pow2 = helper.mul(operands[0], operands[0]); | |||||
auto deno = helper.add(helper.const_val(3.f), x_pow2); | |||||
return helper.mul( | |||||
helper.div( | |||||
helper.add( | |||||
helper.add( | |||||
helper.div(helper.mul(helper.const_val( | |||||
-48.f), | |||||
x_pow2), | |||||
deno), | |||||
helper.const_val(27.f)), | |||||
x_pow2), | |||||
helper.mul(deno, helper.const_val(9.f))), | |||||
operands[1]); | |||||
} | |||||
}; | |||||
//! 1.f / (expf(-(x+y)) + 1.f)) | |||||
template <> | |||||
struct StandardOp<jit::FuseAddSigmoidOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.div(helper.const_val(1.f), | |||||
helper.add(helper.exp(helper.neg(helper.add( | |||||
operands[0], operands[1]))), | |||||
helper.const_val(1.f))); | |||||
} | |||||
}; | |||||
//! x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y) | |||||
template <> | |||||
struct StandardOp<jit::HswishGradOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.select( | |||||
helper.lt(operands[0], helper.const_val(-3.f)), | |||||
helper.const_val(0.f), | |||||
helper.select( | |||||
helper.gt(operands[0], helper.const_val(3.f)), | |||||
operands[1], | |||||
helper.mul( | |||||
helper.div( | |||||
helper.add(helper.mul(helper.const_val( | |||||
2.f), | |||||
operands[0]), | |||||
helper.const_val(3.f)), | |||||
helper.const_val(6.f)), | |||||
operands[1]))); | |||||
} | |||||
}; | |||||
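This is the pointwise derivative of hswish(x) = x * min(max(x + 3, 0), 6) / 6, scaled by the incoming gradient y = operands[1]:

    \frac{d}{dx}\,\mathrm{hswish}(x) =
    \begin{cases}
    0 & x < -3 \\
    (2x + 3)/6 & -3 \le x \le 3 \\
    1 & x > 3
    \end{cases}

The x > 3 branch of the select returns y unchanged because the local derivative there is 1.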
//! (x+y) * min(max(x + y + 3, 0), 6) * (1/6) | |||||
template <> | |||||
struct StandardOp<jit::FuseAddHswishOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
auto sum = helper.add(operands[0], operands[1]); | |||||
auto const_3 = helper.const_val(3.f); | |||||
auto const_0 = helper.const_val(0.f); | |||||
auto const_6 = helper.const_val(6.f); | |||||
auto tmp = helper.add(sum, const_3); | |||||
return helper.div( | |||||
helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)), | |||||
const_6); | |||||
} | |||||
}; | |||||
//! atan2(operands[0], operands[1]) via the atan2_approx polynomial approximation | |||||
template <> | |||||
struct StandardOp<jit::Atan2Op> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return atan2_approx(helper, operands[0], operands[1]); | |||||
} | |||||
}; | |||||
/////////////////////////// ternary op /////////////////////////// | |||||
//! x <= y ? z : ctype(0) | |||||
template <> | |||||
struct StandardOp<jit::CondLeqMovOp> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.select(helper.le(operands[0], operands[1]), operands[2], | |||||
helper.const_val(0.f)); | |||||
} | |||||
}; | |||||
mlir::Value lower_elemwise_to_std(mlir::Operation* op, | |||||
mlir::OpBuilder& builder, | |||||
mlir::Location loc, | |||||
mlir::ValueRange operands); | |||||
//! x * y + z | |||||
template <> | |||||
struct StandardOp<jit::FuseMulAdd3Op> { | |||||
mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||||
ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.add(helper.mul(operands[0], operands[1]), operands[2]); | |||||
} | |||||
}; | |||||
mlir::Value lower_typecvt_to_std(mlir::Operation* op, | |||||
mlir::OpBuilder& builder, | |||||
mlir::Location loc, | |||||
mlir::Value input); | |||||
} // namespace jit | } // namespace jit | ||||
} // namespace mgb | } // namespace mgb | ||||
#endif // MGB_JIT_MLIR | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
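Taken together, this header exposes two entry points: the per-mode StandardOp functors, and the two free functions that lower whole mgb-dialect ops. A minimal usage sketch (assuming builder is positioned at a valid insertion point and a, b hold loaded f32 scalars):

// build fuse_add_relu(a, b) directly from its functor
mlir::Value v = jit::StandardOp<jit::FuseAddReluOp>()(builder, loc, {a, b});

// or let the dispatcher pick the functor from an Elemwise op's mode
mlir::Value w = jit::lower_elemwise_to_std(op, builder, loc, inputs);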
@@ -1,33 +0,0 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/interfaces.td | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#ifndef MGB_MLIR_INTERFACES | |||||
#define MGB_MLIR_INTERFACES | |||||
#ifndef OP_BASE | |||||
include "mlir/IR/OpBase.td" | |||||
#endif | |||||
def GenericBuilderInterface : OpInterface<"GenericBuilder"> { | |||||
let methods = [ | |||||
StaticInterfaceMethod<"TODO", "Type", "getResultType", (ins "ArrayRef<Value>":$operands)>, | |||||
StaticInterfaceMethod<"TODO", "Operation*", "create", (ins | |||||
"OpBuilder*":$builder, | |||||
"Location":$loc, | |||||
"ArrayRef<Value>":$operands | |||||
)>, | |||||
]; | |||||
} | |||||
def ElemwiseOpInterface : OpInterface<"ElemwiseOp">; | |||||
#endif |
@@ -13,18 +13,19 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
#include "./common.h" | |||||
#include "./each_mode.h" | |||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/jit/mlir/ir/dialect.h" | #include "megbrain/jit/mlir/ir/dialect.h" | ||||
#include "megbrain/jit/mlir/ir/passes.h" | #include "megbrain/jit/mlir/ir/passes.h" | ||||
#include "megbrain/jit/mlir/ir/utils.h" | #include "megbrain/jit/mlir/ir/utils.h" | ||||
#include "./each_mode.h" | |||||
#include <llvm/ADT/Sequence.h> | #include <llvm/ADT/Sequence.h> | ||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | #include <mlir/Dialect/Affine/IR/AffineOps.h> | ||||
#include <mlir/IR/StandardTypes.h> | |||||
#include <mlir/Pass/Pass.h> | #include <mlir/Pass/Pass.h> | ||||
#include <mlir/Transforms/DialectConversion.h> | #include <mlir/Transforms/DialectConversion.h> | ||||
#include "mlir/IR/StandardTypes.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace jit; | using namespace jit; | ||||
@@ -57,41 +58,10 @@ void lower_op_to_loops(Operation* op, ValueRange operands, | |||||
rewriter.replaceOp(op, alloc); | rewriter.replaceOp(op, alloc); | ||||
} | } | ||||
template <typename Op, typename LoweredOp> | |||||
struct UnaryOpLowering : public ConversionPattern { | |||||
UnaryOpLowering(MLIRContext* ctx) | |||||
: ConversionPattern(Op::getOperationName(), 1, ctx) {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
auto loc = op->getLoc(); | |||||
lower_op_to_loops( | |||||
op, operands, rewriter, | |||||
[loc](OpBuilder& builder, ValueRange memref_operands, | |||||
ValueRange loop_ivs) { | |||||
typename Op::Adaptor binary_adaptor(memref_operands); | |||||
LoweredOp lower_op; | |||||
auto loaded_lhs = get_operand<AffineLoadOp>( | |||||
builder, loc, binary_adaptor.lhs(), loop_ivs); | |||||
return lower_op(builder, loc, {loaded_lhs}); | |||||
}); | |||||
return success(); | |||||
} | |||||
}; | |||||
#define cb(_op, _) \ | |||||
using _op##Lowering = UnaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) | |||||
#undef cb | |||||
template <typename Op, typename LoweredOp> | |||||
struct BinaryOpLowering : public ConversionPattern { | |||||
BinaryOpLowering(MLIRContext* ctx) | |||||
: ConversionPattern(Op::getOperationName(), 1, ctx) {} | |||||
struct ElemwiseLowering : public ConversionPattern { | |||||
ElemwiseLowering(MLIRContext* ctx) | |||||
: ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1, | |||||
ctx) {} | |||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
Operation* op, ArrayRef<Value> operands, | Operation* op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter& rewriter) const final { | ConversionPatternRewriter& rewriter) const final { | ||||
@@ -101,83 +71,51 @@ struct BinaryOpLowering : public ConversionPattern { | |||||
dst_layout.init_contiguous_stride(); | dst_layout.init_contiguous_stride(); | ||||
lower_op_to_loops( | lower_op_to_loops( | ||||
op, operands, rewriter, | op, operands, rewriter, | ||||
[dst_layout, loc, this](OpBuilder& builder, | |||||
ValueRange memref_operands, | |||||
ValueRange loop_ivs) { | |||||
typename Op::Adaptor binary_adaptor(memref_operands); | |||||
LoweredOp lower_op; | |||||
auto loaded_lhs = get_affine_load_op(builder, loc, | |||||
binary_adaptor.lhs(), | |||||
loop_ivs, dst_layout); | |||||
auto loaded_rhs = get_affine_load_op(builder, loc, | |||||
binary_adaptor.rhs(), | |||||
loop_ivs, dst_layout); | |||||
return lower_op(builder, loc, {loaded_lhs, loaded_rhs}); | |||||
[dst_layout, loc, op](OpBuilder& builder, | |||||
ValueRange memref_operands, | |||||
ValueRange loop_ivs) { | |||||
auto inputs = llvm::to_vector<4>(llvm::map_range( | |||||
memref_operands, [&](mlir::Value val) { | |||||
return get_affine_load_op(builder, loc, val, | |||||
loop_ivs, dst_layout); | |||||
})); | |||||
return lower_elemwise_to_std(op, builder, loc, inputs); | |||||
}); | }); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
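The body of lower_elemwise_to_std is not in this hunk; presumably (see each_mode.h) it dispatches on the Elemwise op's mode attribute to the matching StandardOp functor, along the lines of this hypothetical sketch — the mode() accessor, the Mode enum spelling, and the case list are assumptions here:

mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
                                  mlir::Location loc, mlir::ValueRange operands) {
    switch (llvm::cast<dialect::Elemwise>(op).mode()) {   // hypothetical accessor
        case megdnn::param::Elemwise::Mode::ADD:          // assumed enum spelling
            return StandardOp<jit::AddOp>()(builder, loc, operands);
        // ... one case per mode, typically stamped out with the
        // MLIR_MGB_FOREACH_ELEMWISE_MODE_* macros seen earlier ...
        default:
            mgb_throw(InternalError, "unsupported elemwise mode");
    }
}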
#define cb(_op, _) \ | |||||
using _op##Lowering = BinaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) | |||||
#undef cb | |||||
template <typename Op, typename LoweredOp> | |||||
struct TernaryOpLowering : public ConversionPattern { | |||||
TernaryOpLowering(MLIRContext* ctx) | |||||
: ConversionPattern(Op::getOperationName(), 1, ctx) {} | |||||
struct TypeCvtLowering : public ConversionPattern { | |||||
TypeCvtLowering(MLIRContext* ctx) | |||||
: ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1, | |||||
ctx) {} | |||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
Operation* op, ArrayRef<Value> operands, | Operation* op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter& rewriter) const final { | ConversionPatternRewriter& rewriter) const final { | ||||
auto loc = op->getLoc(); | auto loc = op->getLoc(); | ||||
auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>(); | |||||
megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type); | |||||
dst_layout.init_contiguous_stride(); | |||||
lower_op_to_loops( | lower_op_to_loops( | ||||
op, operands, rewriter, | op, operands, rewriter, | ||||
[dst_layout, loc](OpBuilder& builder, | |||||
ValueRange memref_operands, | |||||
ValueRange loop_ivs) { | |||||
typename Op::Adaptor ternary_adaptor(memref_operands); | |||||
LoweredOp lower_op; | |||||
auto loaded_x = get_affine_load_op(builder, loc, | |||||
ternary_adaptor.x(), | |||||
loop_ivs, dst_layout); | |||||
auto loaded_y = get_affine_load_op(builder, loc, | |||||
ternary_adaptor.y(), | |||||
loop_ivs, dst_layout); | |||||
auto loaded_z = get_affine_load_op(builder, loc, | |||||
ternary_adaptor.z(), | |||||
loop_ivs, dst_layout); | |||||
return lower_op(builder, loc, | |||||
{loaded_x, loaded_y, loaded_z}); | |||||
[loc, op](OpBuilder& builder, ValueRange memref_operands, | |||||
ValueRange loop_ivs) { | |||||
mlir::Value input = get_operand<AffineLoadOp>( | |||||
builder, loc, memref_operands[0], loop_ivs); | |||||
return lower_typecvt_to_std(op, builder, loc, input); | |||||
}); | }); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
#define cb(_op, _) \ | |||||
using _op##Lowering = \ | |||||
TernaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) | |||||
#undef cb | |||||
struct AssignOpLowering : public ConversionPattern { | struct AssignOpLowering : public ConversionPattern { | ||||
AssignOpLowering(MLIRContext* ctx) | AssignOpLowering(MLIRContext* ctx) | ||||
: ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {} | |||||
: ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) { | |||||
} | |||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
Operation* op, ArrayRef<Value> operands, | Operation* op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter& rewriter) const final { | ConversionPatternRewriter& rewriter) const final { | ||||
auto loc = op->getLoc(); | auto loc = op->getLoc(); | ||||
auto memref_type = operands[0].getType().cast<MemRefType>(); | auto memref_type = operands[0].getType().cast<MemRefType>(); | ||||
AssignOpAdaptor assign_adaptor(operands); | |||||
dialect::AssignOpAdaptor assign_adaptor(operands); | |||||
llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0); | llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0); | ||||
llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1); | llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1); | ||||
@@ -195,10 +133,10 @@ struct AssignOpLowering : public ConversionPattern { | |||||
} | } | ||||
}; | }; | ||||
struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { | |||||
using OpRewritePattern<jit::ReturnOp>::OpRewritePattern; | |||||
struct ReturnOpLowering : public OpRewritePattern<dialect::ReturnOp> { | |||||
using OpRewritePattern<dialect::ReturnOp>::OpRewritePattern; | |||||
LogicalResult matchAndRewrite(jit::ReturnOp op, | |||||
LogicalResult matchAndRewrite(dialect::ReturnOp op, | |||||
PatternRewriter& rewriter) const final { | PatternRewriter& rewriter) const final { | ||||
// We lower "mgb.return" directly to "std.return". | // We lower "mgb.return" directly to "std.return". | ||||
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | ||||
@@ -207,12 +145,12 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { | |||||
}; | }; | ||||
struct ConstantScalarOpLowering | struct ConstantScalarOpLowering | ||||
: public OpRewritePattern<jit::ConstantScalarOp> { | |||||
using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern; | |||||
: public OpRewritePattern<dialect::ConstantScalarOp> { | |||||
using OpRewritePattern<dialect::ConstantScalarOp>::OpRewritePattern; | |||||
LogicalResult matchAndRewrite(jit::ConstantScalarOp op, | |||||
LogicalResult matchAndRewrite(dialect::ConstantScalarOp op, | |||||
PatternRewriter& rewriter) const final { | PatternRewriter& rewriter) const final { | ||||
ConstantScalarOpAdaptor constant_scalar_adaptor(op); | |||||
dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op); | |||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>( | rewriter.replaceOpWithNewOp<mlir::ConstantOp>( | ||||
op, constant_scalar_adaptor.value()); | op, constant_scalar_adaptor.value()); | ||||
return success(); | return success(); | ||||
@@ -234,14 +172,9 @@ public: | |||||
target.addIllegalDialect<MgbDialect>(); | target.addIllegalDialect<MgbDialect>(); | ||||
OwningRewritePatternList patterns; | OwningRewritePatternList patterns; | ||||
#define cb(_op, _) _op##Lowering, | |||||
patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY( | |||||
cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) | |||||
ReturnOpLowering, | |||||
patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering, | |||||
AssignOpLowering, ConstantScalarOpLowering>( | AssignOpLowering, ConstantScalarOpLowering>( | ||||
&getContext()); | &getContext()); | ||||
#undef cb | |||||
if (failed(applyPartialConversion(getFunction(), target, patterns))) { | if (failed(applyPartialConversion(getFunction(), target, patterns))) { | ||||
signalPassFailure(); | signalPassFailure(); | ||||
@@ -13,12 +13,19 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
#include "./common.h" | |||||
#include "./each_mode.h" | #include "./each_mode.h" | ||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/jit/mlir/ir/dialect.h" | #include "megbrain/jit/mlir/ir/dialect.h" | ||||
#include "megbrain/jit/mlir/ir/passes.h" | #include "megbrain/jit/mlir/ir/passes.h" | ||||
#include "megbrain/jit/mlir/ir/utils.h" | #include "megbrain/jit/mlir/ir/utils.h" | ||||
#include <llvm/ADT/PointerUnion.h> | |||||
#include <llvm/ADT/Sequence.h> | |||||
#include <llvm/ADT/SetVector.h> | |||||
#include <llvm/ADT/Twine.h> | |||||
#include <llvm/IR/Type.h> | |||||
#include <mlir/Dialect/GPU/GPUDialect.h> | #include <mlir/Dialect/GPU/GPUDialect.h> | ||||
#include <mlir/Dialect/SCF/SCF.h> | #include <mlir/Dialect/SCF/SCF.h> | ||||
#include <mlir/Dialect/StandardOps/IR/Ops.h> | #include <mlir/Dialect/StandardOps/IR/Ops.h> | ||||
@@ -27,12 +34,6 @@ | |||||
#include <mlir/Pass/Pass.h> | #include <mlir/Pass/Pass.h> | ||||
#include <mlir/Transforms/DialectConversion.h> | #include <mlir/Transforms/DialectConversion.h> | ||||
#include <llvm/ADT/PointerUnion.h> | |||||
#include <llvm/ADT/Sequence.h> | |||||
#include <llvm/ADT/SetVector.h> | |||||
#include <llvm/ADT/Twine.h> | |||||
#include <llvm/IR/Type.h> | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace jit; | using namespace jit; | ||||
@@ -59,7 +60,7 @@ megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) { | |||||
block_iter++) { | block_iter++) { | ||||
for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend(); | for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend(); | ||||
op_iter++) { | op_iter++) { | ||||
auto op = llvm::dyn_cast_or_null<AssignOp>(&(*op_iter)); | |||||
auto op = llvm::dyn_cast_or_null<dialect::AssignOp>(&(*op_iter)); | |||||
if (op && op.getNumOperands() > 0) { | if (op && op.getNumOperands() > 0) { | ||||
return mlir_type_to_layout(*(op.operand_type_begin())); | return mlir_type_to_layout(*(op.operand_type_begin())); | ||||
} | } | ||||
@@ -81,64 +82,27 @@ std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter, | |||||
idxs.resize(dst.ndim); | idxs.resize(dst.ndim); | ||||
mlir::Value dim_index = index; | mlir::Value dim_index = index; | ||||
for (int i = dst.ndim - 1; i >= 0; i--) { | for (int i = dst.ndim - 1; i >= 0; i--) { | ||||
auto cur_index = helper.modI(dim_index, helper.constI(dst[i])); | |||||
auto cur_index = helper.modI(dim_index, helper.const_i32(dst[i])); | |||||
idxs[i] = cur_index; | idxs[i] = cur_index; | ||||
dim_index = helper.divI(dim_index, helper.constI(dst[i])); | |||||
dim_index = helper.divI(dim_index, helper.const_i32(dst[i])); | |||||
} | } | ||||
megdnn::TensorLayout src_layout = mlir_type_to_layout(type); | megdnn::TensorLayout src_layout = mlir_type_to_layout(type); | ||||
src_layout.init_contiguous_stride(); | src_layout.init_contiguous_stride(); | ||||
for (int i = 0; i < type.getRank(); ++i) { | for (int i = 0; i < type.getRank(); ++i) { | ||||
if (src_layout[i] == 1) { | if (src_layout[i] == 1) { | ||||
idxs[i] = helper.constI(0); | |||||
idxs[i] = helper.const_i32(0); | |||||
} | } | ||||
} | } | ||||
return idxs; | return idxs; | ||||
} else { | } else { | ||||
return {index}; | return {index}; | ||||
} | } | ||||
} | } | ||||
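In effect, get_multidim_tid delinearizes the flat thread index t into coordinates of the contiguous destination shape (d_0, ..., d_{n-1}), working from the innermost dimension outward:

    i_k = \left\lfloor t \Big/ \prod_{j > k} d_j \right\rfloor \bmod d_k

Any source axis whose extent is 1 is then pinned to index 0, which is what implements implicit broadcasting against the destination layout.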
template <typename Op, typename LoweredOp> | |||||
struct UnaryOpLowering : public ConversionPattern { | |||||
UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||||
: ConversionPattern(Op::getOperationName(), 1, ctx), | |||||
m_launch_op{launch_op} {} | |||||
LogicalResult matchAndRewrite( | |||||
Operation* op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter& rewriter) const final { | |||||
auto loc = op->getLoc(); | |||||
typename Op::Adaptor binary_adaptor(operands); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | |||||
auto dst_layout = output_layout(m_launch_op); | |||||
auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(), | |||||
dst_layout); | |||||
auto loaded_lhs = | |||||
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index); | |||||
LoweredOp lower_op; | |||||
rewriter.replaceOp(op, lower_op(rewriter, loc, {loaded_lhs})); | |||||
return success(); | |||||
} | |||||
private: | |||||
gpu::LaunchOp& m_launch_op; | |||||
}; | |||||
#define cb(_op, _) \ | |||||
using _op##Lowering = UnaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) | |||||
#undef cb | |||||
template <typename Op, typename LoweredOp> | |||||
struct BinaryOpLowering : public ConversionPattern { | |||||
BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||||
: ConversionPattern(Op::getOperationName(), 1, ctx), | |||||
struct ElemwiseLowering : public ConversionPattern { | |||||
ElemwiseLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||||
: ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx), | |||||
m_launch_op{launch_op} {} | m_launch_op{launch_op} {} | ||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
@@ -146,23 +110,18 @@ struct BinaryOpLowering : public ConversionPattern { | |||||
ConversionPatternRewriter& rewriter) const final { | ConversionPatternRewriter& rewriter) const final { | ||||
auto loc = op->getLoc(); | auto loc = op->getLoc(); | ||||
typename Op::Adaptor binary_adaptor(operands); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | ||||
auto dst_layout = output_layout(m_launch_op); | auto dst_layout = output_layout(m_launch_op); | ||||
auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(), | |||||
dst_layout); | |||||
auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(), | |||||
dst_layout); | |||||
auto loaded_lhs = get_operand<LoadOp>(rewriter, loc, | |||||
binary_adaptor.lhs(), lhs_index); | |||||
auto loaded_rhs = get_operand<LoadOp>(rewriter, loc, | |||||
binary_adaptor.rhs(), rhs_index); | |||||
LoweredOp lower_op; | |||||
auto inputs = llvm::to_vector<4>( | |||||
llvm::map_range(operands, [&](mlir::Value val) { | |||||
auto index = | |||||
get_multidim_tid(rewriter, loc, val, dst_layout); | |||||
return get_operand<LoadOp>(rewriter, loc, val, index); | |||||
})); | |||||
rewriter.replaceOp(op, | rewriter.replaceOp(op, | ||||
lower_op(rewriter, loc, {loaded_lhs, loaded_rhs})); | |||||
lower_elemwise_to_std(op, rewriter, loc, inputs)); | |||||
return success(); | return success(); | ||||
} | } | ||||
@@ -170,43 +129,22 @@ private: | |||||
gpu::LaunchOp& m_launch_op; | gpu::LaunchOp& m_launch_op; | ||||
}; | }; | ||||
#define cb(_op, _) \ | |||||
using _op##Lowering = BinaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) | |||||
#undef cb | |||||
template <typename Op, typename LoweredOp> | |||||
struct TernaryOpLowering : public ConversionPattern { | |||||
TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||||
: ConversionPattern(Op::getOperationName(), 1, ctx), | |||||
struct TypeCvtLowering : public ConversionPattern { | |||||
TypeCvtLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||||
: ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx), | |||||
m_launch_op{launch_op} {} | m_launch_op{launch_op} {} | ||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
Operation* op, ArrayRef<Value> operands, | Operation* op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter& rewriter) const final { | ConversionPatternRewriter& rewriter) const final { | ||||
auto loc = op->getLoc(); | auto loc = op->getLoc(); | ||||
typename Op::Adaptor ternary_adaptor(operands); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | ||||
auto dst_layout = output_layout(m_launch_op); | auto dst_layout = output_layout(m_launch_op); | ||||
auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(), | |||||
dst_layout); | |||||
auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(), | |||||
dst_layout); | |||||
auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(), | |||||
dst_layout); | |||||
auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), | |||||
index_x); | |||||
auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), | |||||
index_y); | |||||
auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), | |||||
index_z); | |||||
LoweredOp lower_op; | |||||
rewriter.replaceOp( | |||||
op, lower_op(rewriter, loc, {loaded_x, loaded_y, loaded_z})); | |||||
auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout); | |||||
auto input = get_operand<LoadOp>(rewriter, loc, operands[0], index); | |||||
rewriter.replaceOp(op, lower_typecvt_to_std(op, rewriter, loc, input)); | |||||
return success(); | return success(); | ||||
} | } | ||||
@@ -214,15 +152,9 @@ private: | |||||
gpu::LaunchOp& m_launch_op; | gpu::LaunchOp& m_launch_op; | ||||
}; | }; | ||||
#define cb(_op, _) \ | |||||
using _op##Lowering = \ | |||||
TernaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>; | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) | |||||
#undef cb | |||||
struct ReturnOpLowering : public ConversionPattern { | struct ReturnOpLowering : public ConversionPattern { | ||||
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | ||||
: ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx), | |||||
: ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx), | |||||
m_launch_op{launch_op} {} | m_launch_op{launch_op} {} | ||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
@@ -270,14 +202,14 @@ private: | |||||
}; | }; | ||||
struct ConstantScalarOpLowering | struct ConstantScalarOpLowering | ||||
: public OpRewritePattern<jit::ConstantScalarOp> { | |||||
: public OpRewritePattern<dialect::ConstantScalarOp> { | |||||
ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | ||||
: OpRewritePattern<jit::ConstantScalarOp>(ctx), | |||||
: OpRewritePattern<dialect::ConstantScalarOp>(ctx), | |||||
m_launch_op{launch_op} {} | m_launch_op{launch_op} {} | ||||
LogicalResult matchAndRewrite(jit::ConstantScalarOp op, | |||||
LogicalResult matchAndRewrite(dialect::ConstantScalarOp op, | |||||
PatternRewriter& rewriter) const final { | PatternRewriter& rewriter) const final { | ||||
ConstantScalarOpAdaptor constant_scalar_adaptor(op); | |||||
dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | ||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>( | rewriter.replaceOpWithNewOp<mlir::ConstantOp>( | ||||
@@ -291,7 +223,7 @@ private: | |||||
struct AssignOpLowering : public ConversionPattern { | struct AssignOpLowering : public ConversionPattern { | ||||
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | ||||
: ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx), | |||||
: ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx), | |||||
m_launch_op{launch_op} {} | m_launch_op{launch_op} {} | ||||
LogicalResult matchAndRewrite( | LogicalResult matchAndRewrite( | ||||
@@ -299,7 +231,7 @@ struct AssignOpLowering : public ConversionPattern { | |||||
ConversionPatternRewriter& rewriter) const final { | ConversionPatternRewriter& rewriter) const final { | ||||
auto loc = op->getLoc(); | auto loc = op->getLoc(); | ||||
AssignOpAdaptor assign_adaptor(operands); | |||||
dialect::AssignOpAdaptor assign_adaptor(operands); | |||||
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | ||||
auto dst_layout = output_layout(m_launch_op); | auto dst_layout = output_layout(m_launch_op); | ||||
@@ -343,14 +275,9 @@ public: | |||||
target.addLegalDialect<gpu::GPUDialect>(); | target.addLegalDialect<gpu::GPUDialect>(); | ||||
target.addIllegalDialect<MgbDialect>(); | target.addIllegalDialect<MgbDialect>(); | ||||
#define cb(_op, _) _op##Lowering, | |||||
patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY( | |||||
cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) | |||||
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) | |||||
ReturnOpLowering, | |||||
patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering, | |||||
ConstantScalarOpLowering, AssignOpLowering>( | ConstantScalarOpLowering, AssignOpLowering>( | ||||
&getContext(), launch_op); | &getContext(), launch_op); | ||||
#undef cb | |||||
if (failed(applyPartialConversion(func_op, target, patterns))) { | if (failed(applyPartialConversion(func_op, target, patterns))) { | ||||
signalPassFailure(); | signalPassFailure(); | ||||
@@ -22,7 +22,7 @@ mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x, | |||||
std::vector<mlir::Value>& coeff) { | std::vector<mlir::Value>& coeff) { | ||||
size_t n = coeff.size(); | size_t n = coeff.size(); | ||||
if (n == 0) { | if (n == 0) { | ||||
return helper.const_val(0); | |||||
return helper.const_f32(0); | |||||
} | } | ||||
mlir::Value r = coeff[0]; | mlir::Value r = coeff[0]; | ||||
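The remainder of polynomial() lies outside this hunk, but seeding r with coeff[0] strongly suggests Horner's rule with the highest-degree coefficient first:

    p(x) = (\cdots((c_0 x + c_1)x + c_2)\cdots)x + c_{n-1}

Under that reading, the coefficient vectors passed in below (atan, erf, ndtri) are ordered from highest to lowest degree.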
@@ -40,23 +40,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, | |||||
mlir::Value x) { | mlir::Value x) { | ||||
auto atan_poly = [&](mlir::Value t) { | auto atan_poly = [&](mlir::Value t) { | ||||
std::vector<mlir::Value> coeff = { | std::vector<mlir::Value> coeff = { | ||||
helper.const_val(2.90188402868807315826416015625E-3), | |||||
helper.const_val(-1.62907354533672332763671875E-2), | |||||
helper.const_val(4.3082617223262786865234375E-2), | |||||
helper.const_val(-7.5408883392810821533203125E-2), | |||||
helper.const_val(0.1066047251224517822265625), | |||||
helper.const_val(-0.14209578931331634521484375), | |||||
helper.const_val(0.19993579387664794921875), | |||||
helper.const_val(-0.3333314359188079833984375)}; | |||||
helper.const_f32(2.90188402868807315826416015625E-3), | |||||
helper.const_f32(-1.62907354533672332763671875E-2), | |||||
helper.const_f32(4.3082617223262786865234375E-2), | |||||
helper.const_f32(-7.5408883392810821533203125E-2), | |||||
helper.const_f32(0.1066047251224517822265625), | |||||
helper.const_f32(-0.14209578931331634521484375), | |||||
helper.const_f32(0.19993579387664794921875), | |||||
helper.const_f32(-0.3333314359188079833984375)}; | |||||
auto t2 = helper.mul(t, t); | auto t2 = helper.mul(t, t); | ||||
auto p = polynomial(helper, t2, coeff); | auto p = polynomial(helper, t2, coeff); | ||||
return helper.add(helper.mul(helper.mul(p, t2), t), t); | return helper.add(helper.mul(helper.mul(p, t2), t), t); | ||||
}; | }; | ||||
// constants | // constants | ||||
auto zero = helper.const_val(0); | |||||
auto pi = helper.const_val(3.141592653589793); | |||||
auto pi_over_2 = helper.const_val(1.570796326794897); | |||||
auto zero = helper.const_f32(0); | |||||
auto pi = helper.const_f32(3.141592653589793); | |||||
auto pi_over_2 = helper.const_f32(1.570796326794897); | |||||
// transform the angle into interval [0, pi/4] | // transform the angle into interval [0, pi/4] | ||||
auto ax = helper.abs(x); | auto ax = helper.abs(x); | ||||
@@ -83,23 +83,23 @@ mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, | |||||
// original book: | // original book: | ||||
// Numerical Recipes in Fortran 77: The Art of Scientific Computing | // Numerical Recipes in Fortran 77: The Art of Scientific Computing | ||||
mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) { | mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) { | ||||
auto zero = helper.const_val(0); | |||||
auto one = helper.const_val(1); | |||||
auto half = helper.const_val(0.5); | |||||
auto zero = helper.const_f32(0); | |||||
auto one = helper.const_f32(1); | |||||
auto half = helper.const_f32(0.5); | |||||
auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x)))); | auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x)))); | ||||
std::vector<mlir::Value> coeff = { | std::vector<mlir::Value> coeff = { | ||||
helper.const_val(0.17087277), | |||||
helper.const_val(-0.82215223), | |||||
helper.const_val(1.48851587), | |||||
helper.const_val(-1.13520398), | |||||
helper.const_val(0.27886807), | |||||
helper.const_val(-0.18628806), | |||||
helper.const_val(0.09678418), | |||||
helper.const_val(0.37409196), | |||||
helper.const_val(1.00002368), | |||||
helper.const_val(-1.26551223)}; | |||||
helper.const_f32(0.17087277), | |||||
helper.const_f32(-0.82215223), | |||||
helper.const_f32(1.48851587), | |||||
helper.const_f32(-1.13520398), | |||||
helper.const_f32(0.27886807), | |||||
helper.const_f32(-0.18628806), | |||||
helper.const_f32(0.09678418), | |||||
helper.const_f32(0.37409196), | |||||
helper.const_f32(1.00002368), | |||||
helper.const_f32(-1.26551223)}; | |||||
auto p = polynomial(helper, t, coeff); | auto p = polynomial(helper, t, coeff); | ||||
auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x)))); | auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x)))); | ||||
@@ -130,25 +130,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { | |||||
// polynomial P | // polynomial P | ||||
auto P = [&](mlir::Value i, mlir::Value cond) { | auto P = [&](mlir::Value i, mlir::Value cond) { | ||||
std::vector<mlir::Value> coeff0 = { | std::vector<mlir::Value> coeff0 = { | ||||
helper.const_val(4.05544892305962419923E0), | |||||
helper.const_val(3.15251094599893866154E1), | |||||
helper.const_val(5.71628192246421288162E1), | |||||
helper.const_val(4.40805073893200834700E1), | |||||
helper.const_val(1.46849561928858024014E1), | |||||
helper.const_val(2.18663306850790267539E0), | |||||
helper.const_val(-1.40256079171354495875E-1), | |||||
helper.const_val(-3.50424626827848203418E-2), | |||||
helper.const_val(-8.57456785154685413611E-4)}; | |||||
helper.const_f32(4.05544892305962419923E0), | |||||
helper.const_f32(3.15251094599893866154E1), | |||||
helper.const_f32(5.71628192246421288162E1), | |||||
helper.const_f32(4.40805073893200834700E1), | |||||
helper.const_f32(1.46849561928858024014E1), | |||||
helper.const_f32(2.18663306850790267539E0), | |||||
helper.const_f32(-1.40256079171354495875E-1), | |||||
helper.const_f32(-3.50424626827848203418E-2), | |||||
helper.const_f32(-8.57456785154685413611E-4)}; | |||||
std::vector<mlir::Value> coeff1 = { | std::vector<mlir::Value> coeff1 = { | ||||
helper.const_val(3.23774891776946035970E0), | |||||
helper.const_val(6.91522889068984211695E0), | |||||
helper.const_val(3.93881025292474443415E0), | |||||
helper.const_val(1.33303460815807542389E0), | |||||
helper.const_val(2.01485389549179081538E-1), | |||||
helper.const_val(1.23716634817820021358E-2), | |||||
helper.const_val(3.01581553508235416007E-4), | |||||
helper.const_val(2.65806974686737550832E-6), | |||||
helper.const_val(6.23974539184983293730E-9)}; | |||||
helper.const_f32(3.23774891776946035970E0), | |||||
helper.const_f32(6.91522889068984211695E0), | |||||
helper.const_f32(3.93881025292474443415E0), | |||||
helper.const_f32(1.33303460815807542389E0), | |||||
helper.const_f32(2.01485389549179081538E-1), | |||||
helper.const_f32(1.23716634817820021358E-2), | |||||
helper.const_f32(3.01581553508235416007E-4), | |||||
helper.const_f32(2.65806974686737550832E-6), | |||||
helper.const_f32(6.23974539184983293730E-9)}; | |||||
return helper.select(cond, | return helper.select(cond, | ||||
polynomial(helper, i, coeff0), | polynomial(helper, i, coeff0), | ||||
polynomial(helper, i, coeff1)); | polynomial(helper, i, coeff1)); | ||||
@@ -157,25 +157,25 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { | |||||
// polynomial Q | // polynomial Q | ||||
auto Q = [&](mlir::Value i, mlir::Value cond) { | auto Q = [&](mlir::Value i, mlir::Value cond) { | ||||
std::vector<mlir::Value> coeff0 = { | std::vector<mlir::Value> coeff0 = { | ||||
helper.const_val(1.f), | |||||
helper.const_val(1.57799883256466749731E1), | |||||
helper.const_val(4.53907635128879210584E1), | |||||
helper.const_val(4.13172038254672030440E1), | |||||
helper.const_val(1.50425385692907503408E1), | |||||
helper.const_val(2.50464946208309415979E0), | |||||
helper.const_val(-1.42182922854787788574E-1), | |||||
helper.const_val(-3.80806407691578277194E-2), | |||||
helper.const_val(-9.33259480895457427372E-4)}; | |||||
helper.const_f32(1.f), | |||||
helper.const_f32(1.57799883256466749731E1), | |||||
helper.const_f32(4.53907635128879210584E1), | |||||
helper.const_f32(4.13172038254672030440E1), | |||||
helper.const_f32(1.50425385692907503408E1), | |||||
helper.const_f32(2.50464946208309415979E0), | |||||
helper.const_f32(-1.42182922854787788574E-1), | |||||
helper.const_f32(-3.80806407691578277194E-2), | |||||
helper.const_f32(-9.33259480895457427372E-4)}; | |||||
std::vector<mlir::Value> coeff1 = { | std::vector<mlir::Value> coeff1 = { | ||||
helper.const_val(1.f), | |||||
helper.const_val(6.02427039364742014255E0), | |||||
helper.const_val(3.67983563856160859403E0), | |||||
helper.const_val(1.37702099489081330271E0), | |||||
helper.const_val(2.16236993594496635890E-1), | |||||
helper.const_val(1.34204006088543189037E-2), | |||||
helper.const_val(3.28014464682127739104E-4), | |||||
helper.const_val(2.89247864745380683936E-6), | |||||
helper.const_val(6.79019408009981274425E-9)}; | |||||
helper.const_f32(1.f), | |||||
helper.const_f32(6.02427039364742014255E0), | |||||
helper.const_f32(3.67983563856160859403E0), | |||||
helper.const_f32(1.37702099489081330271E0), | |||||
helper.const_f32(2.16236993594496635890E-1), | |||||
helper.const_f32(1.34204006088543189037E-2), | |||||
helper.const_f32(3.28014464682127739104E-4), | |||||
helper.const_f32(2.89247864745380683936E-6), | |||||
helper.const_f32(6.79019408009981274425E-9)}; | |||||
return helper.select(cond, | return helper.select(cond, | ||||
polynomial(helper, i, coeff0), | polynomial(helper, i, coeff0), | ||||
polynomial(helper, i, coeff1)); | polynomial(helper, i, coeff1)); | ||||
@@ -184,37 +184,37 @@ mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { | |||||
// polynomial R | // polynomial R | ||||
auto R = [&](mlir::Value i) { | auto R = [&](mlir::Value i) { | ||||
std::vector<mlir::Value> coeff = { | std::vector<mlir::Value> coeff = { | ||||
helper.const_val(-5.99633501014107895267E1), | |||||
helper.const_val(9.80010754185999661536E1), | |||||
helper.const_val(-5.66762857469070293439E1), | |||||
helper.const_val(1.39312609387279679503E1), | |||||
helper.const_val(-1.23916583867381258016E0)}; | |||||
helper.const_f32(-5.99633501014107895267E1), | |||||
helper.const_f32(9.80010754185999661536E1), | |||||
helper.const_f32(-5.66762857469070293439E1), | |||||
helper.const_f32(1.39312609387279679503E1), | |||||
helper.const_f32(-1.23916583867381258016E0)}; | |||||
return polynomial(helper, i, coeff); | return polynomial(helper, i, coeff); | ||||
}; | }; | ||||
// polynomial S | // polynomial S | ||||
auto S = [&](mlir::Value i) { | auto S = [&](mlir::Value i) { | ||||
std::vector<mlir::Value> coeff = { | std::vector<mlir::Value> coeff = { | ||||
helper.const_val(1.f), | |||||
helper.const_val(1.95448858338141759834E0), | |||||
helper.const_val(4.67627912898881538453E0), | |||||
helper.const_val(8.63602421390890590575E1), | |||||
helper.const_val(-2.25462687854119370527E2), | |||||
helper.const_val(2.00260212380060660359E2), | |||||
helper.const_val(-8.20372256168333339912E1), | |||||
helper.const_val(1.59056225126211695515E1), | |||||
helper.const_val(-1.18331621121330003142E0)}; | |||||
helper.const_f32(1.f), | |||||
helper.const_f32(1.95448858338141759834E0), | |||||
helper.const_f32(4.67627912898881538453E0), | |||||
helper.const_f32(8.63602421390890590575E1), | |||||
helper.const_f32(-2.25462687854119370527E2), | |||||
helper.const_f32(2.00260212380060660359E2), | |||||
helper.const_f32(-8.20372256168333339912E1), | |||||
helper.const_f32(1.59056225126211695515E1), | |||||
helper.const_f32(-1.18331621121330003142E0)}; | |||||
return polynomial(helper, i, coeff); | return polynomial(helper, i, coeff); | ||||
}; | }; | ||||
// constants | // constants | ||||
auto zero = helper.const_val(0); | |||||
auto one = helper.const_val(1); | |||||
auto half = helper.const_val(0.5); | |||||
auto eight = helper.const_val(8); | |||||
auto minus_2 = helper.const_val(-2); | |||||
auto exp_minus_2 = helper.const_val(0.135335283236); // exp(-2) | |||||
auto sqrt_2pi = helper.const_val(2.506628274631); // sqrt(2pi) | |||||
auto zero = helper.const_f32(0); | |||||
auto one = helper.const_f32(1); | |||||
auto half = helper.const_f32(0.5); | |||||
auto eight = helper.const_f32(8); | |||||
auto minus_2 = helper.const_f32(-2); | |||||
auto exp_minus_2 = helper.const_f32(0.135335283236); // exp(-2) | |||||
auto sqrt_2pi = helper.const_f32(2.506628274631); // sqrt(2pi) | |||||
// conditions | // conditions | ||||
auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2) | auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2) | ||||
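For context, ndtri is the inverse of the standard normal CDF,

    \mathrm{ndtri}(p) = \Phi^{-1}(p), \qquad
    \Phi(x) = \tfrac{1}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr),

and the P/Q and R/S rational pairs with these exact coefficients appear to follow the classic Cephes ndtri routine: P/Q serves the central region, R/S the tails, selected by the exp(-2) threshold computed above.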
@@ -1,216 +0,0 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/ops.td | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#ifndef MGB_MLIR_OPS | |||||
#define MGB_MLIR_OPS | |||||
include "mlir/IR/OpBase.td" | |||||
include "mlir/Interfaces/SideEffectInterfaces.td" | |||||
include "./interfaces.td" | |||||
include "./predicates.td" | |||||
def Mgb_Dialect : Dialect { | |||||
let name = "mgb"; | |||||
let cppNamespace = "mgb::jit"; | |||||
} | |||||
class ElemwiseBuilderImpl { | |||||
code ElemwiseBuilderImpl_create = [{ | |||||
static Operation* create(OpBuilder* builder, Location loc, ValueRange operands) { | |||||
OperationState state(loc, getOperationName()); | |||||
state.addOperands(operands); | |||||
state.addTypes(getResultType(operands)); | |||||
return builder->createOperation(state); | |||||
} | |||||
}]; | |||||
} | |||||
class ElemwiseOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||||
Op<Mgb_Dialect, mnemonic, !listconcat(traits, [ElemwiseOpInterface, | |||||
GenericBuilderInterface])>, ElemwiseBuilderImpl; | |||||
class GenericOp<string mnemonic, list<OpTrait> traits = []> : | |||||
Op<Mgb_Dialect, mnemonic, traits>; | |||||
class ElemwiseUnaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||||
ElemwiseOp<mnemonic, traits> { | |||||
let arguments = (ins F32MemRef:$lhs); | |||||
let results = (outs F32MemRef); | |||||
let builders = [OpBuilder< | |||||
"Builder* builder, OperationState& result, ValueRange operands", [{ | |||||
result.addOperands(operands); | |||||
result.addTypes(getResultType(operands)); | |||||
}]>, OpBuilder < | |||||
"OpBuilder& builder, OperationState& result, Value lhs", [{ | |||||
result.addOperands(lhs); | |||||
result.addTypes(getResultType({lhs})); | |||||
}] | |||||
>]; | |||||
let extraClassDeclaration = [{ | |||||
static Type getResultType(ValueRange operands) { | |||||
return deduce_result_type(operands); | |||||
} | |||||
}] # ElemwiseBuilderImpl_create; | |||||
} | |||||
def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>; | |||||
def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>; | |||||
def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>; | |||||
def AcosOp : ElemwiseUnaryOp<"acos", [NoSideEffect]>; | |||||
def AsinOp : ElemwiseUnaryOp<"asin", [NoSideEffect]>; | |||||
def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>; | |||||
def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>; | |||||
def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>; | |||||
def ExpM1Op : ElemwiseUnaryOp<"expm1", [NoSideEffect]>; | |||||
def FloorOp : ElemwiseUnaryOp<"floor", [NoSideEffect]>; | |||||
def LogOp : ElemwiseUnaryOp<"log", [NoSideEffect]>; | |||||
def Log1POp : ElemwiseUnaryOp<"log1p", [NoSideEffect]>; | |||||
def SigmoidOp: ElemwiseUnaryOp<"sigmoid", [NoSideEffect]>; | |||||
def SinOp : ElemwiseUnaryOp<"sin", [NoSideEffect]>; | |||||
def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>; | |||||
def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>; | |||||
def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>; | |||||
def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>; | |||||
def ErfOp : ElemwiseUnaryOp<"erf", [NoSideEffect]>; | |||||
def ErfInvOp : ElemwiseUnaryOp<"erfinv", [NoSideEffect]>; | |||||
def ErfCOp : ElemwiseUnaryOp<"erfc", [NoSideEffect]>; | |||||
def ErfCInvOp : ElemwiseUnaryOp<"erfcinv", [NoSideEffect]>; | |||||
class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||||
ElemwiseOp<mnemonic, traits> { | |||||
let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs); | |||||
let results = (outs F32MemRef); | |||||
let builders = [OpBuilder< | |||||
"Builder* builder, OperationState& result, ValueRange operands", [{ | |||||
result.addOperands(operands); | |||||
result.addTypes(getResultType(operands)); | |||||
}] | |||||
>, OpBuilder < | |||||
"OpBuilder& builder, OperationState& result, Value lhs, Value rhs", [{ | |||||
result.addOperands(lhs); | |||||
result.addOperands(rhs); | |||||
result.addTypes(getResultType({lhs, rhs})); | |||||
}] | |||||
>]; | |||||
let extraClassDeclaration = [{ | |||||
static Type getResultType(ValueRange operands) { | |||||
return deduce_result_type(operands); | |||||
} | |||||
}] # ElemwiseBuilderImpl_create; | |||||
} | |||||
def AbsGradOp : ElemwiseBinaryOp<"abs_grad", [NoSideEffect]>; | |||||
def AddOp : ElemwiseBinaryOp<"add", [Commutative, NoSideEffect]>; | |||||
def FloorDivOp : ElemwiseBinaryOp<"floor_div", [NoSideEffect]>; | |||||
def MaxOp : ElemwiseBinaryOp<"max", [Commutative, NoSideEffect]>; | |||||
def MinOp : ElemwiseBinaryOp<"min", [Commutative, NoSideEffect]>; | |||||
def ModOp : ElemwiseBinaryOp<"mod", [NoSideEffect]>; | |||||
def MulOp : ElemwiseBinaryOp<"mul", [Commutative, NoSideEffect]>; | |||||
def SubOp : ElemwiseBinaryOp<"sub", [NoSideEffect]>; | |||||
def SigmoidGradOp : ElemwiseBinaryOp<"sigmoid_grad", [NoSideEffect]>; | |||||
def SwishGt0Op : ElemwiseBinaryOp<"switch_gt0", [NoSideEffect]>; | |||||
def TanhGradOp : ElemwiseBinaryOp<"tanh_grad", [NoSideEffect]>; | |||||
def LtOp : ElemwiseBinaryOp<"lt", [NoSideEffect]>; | |||||
def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>; | |||||
def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>; | |||||
def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>; | |||||
def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>; | |||||
def PowOp : ElemwiseBinaryOp<"pow", [NoSideEffect]>; | |||||
def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>; | |||||
def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>; | |||||
def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>; | |||||
def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>; | |||||
def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>; | |||||
def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>; | |||||
def Atan2Op : ElemwiseBinaryOp<"atan2", [NoSideEffect]>; | |||||
class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||||
ElemwiseOp<mnemonic, traits> { | |||||
let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z); | |||||
let results = (outs F32MemRef); | |||||
let builders = [OpBuilder< | |||||
"Builder* builder, OperationState& result, ValueRange operands", [{ | |||||
result.addOperands(operands); | |||||
result.addTypes(getResultType(operands)); | |||||
}] | |||||
>, OpBuilder < | |||||
"OpBuilder& builder, OperationState& result, Value x, Value y, Value z", [{ | |||||
result.addOperands(x); | |||||
result.addOperands(y); | |||||
result.addOperands(z); | |||||
result.addTypes(getResultType({x, y, z})); | |||||
}] | |||||
>]; | |||||
let extraClassDeclaration = [{ | |||||
static Type getResultType(ValueRange operands) { | |||||
return deduce_result_type(operands); | |||||
} | |||||
}] # ElemwiseBuilderImpl_create; | |||||
} | |||||
def CondLeqMovOp: ElemwiseTernaryOp<"cond_leq_mov", [NoSideEffect]>; | |||||
def FuseMulAdd3Op: ElemwiseTernaryOp<"fuse_mul_add3", [NoSideEffect]>; | |||||
def ReturnOp : GenericOp<"return", | |||||
[NoSideEffect, HasParent<"FuncOp">, Terminator]> { | |||||
let summary = "return operation"; | |||||
let description = [{ | |||||
The "return" operation represents a return operation within a function. | |||||
The operation takes no operands and produces no results. | |||||
}]; | |||||
// The return operation takes no operands; it simply terminates the | |||||
// enclosing function. | |||||
let arguments = (ins); | |||||
// With no operands, the assembly format reduces to the attribute dictionary. | |||||
let assemblyFormat = "attr-dict"; | |||||
} | |||||
def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> { | |||||
let summary = "scalar constant"; | |||||
let arguments = (ins AnyAttr:$value); | |||||
let results = (outs F32:$result); | |||||
let builders = [OpBuilder< | |||||
"Builder* builder, OperationState& result, float value", [{ | |||||
result.addAttribute("value", builder->getF32FloatAttr(value)); | |||||
result.addTypes(builder->getF32Type()); | |||||
}] | |||||
>]; | |||||
let extraClassDeclaration = [{ | |||||
Attribute getValue() { return getAttr("value"); } | |||||
FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); } | |||||
}]; | |||||
} | |||||
def AssignOp : GenericOp<"assign", []> { | |||||
let summary = "assign op"; | |||||
let description = [{ | |||||
assigns rhs to lhs in place and produces no results | |||||
}]; | |||||
let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs); | |||||
} | |||||
#endif |
@@ -1,24 +0,0 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/predicates.td | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#ifndef MGB_MLIR_PREDICATES | |||||
#define MGB_MLIR_PREDICATES | |||||
#ifndef OP_BASE | |||||
include "mlir/IR/OpBase.td" | |||||
#endif | |||||
def ElemwiseFloatAny : TypeConstraint< | |||||
CPred<"is_elemwise_float($_self)">, "elemwise-float">; | |||||
#endif | |||||
@@ -0,0 +1,115 @@ | |||||
/** | |||||
* \file src/jit/impl/mlir/ir/types.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT && MGB_JIT_MLIR | |||||
#include "./types.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/exception.h" | |||||
#include "megbrain/jit/mlir/ir/utils.h" | |||||
namespace mgb { | |||||
namespace jit { | |||||
mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, | |||||
mlir::MLIRContext* ctx) { | |||||
switch (type.enumv()) { | |||||
case megdnn::DTypeEnum::Float32: | |||||
return mlir::FloatType::getF32(ctx); | |||||
case megdnn::DTypeEnum::Uint8: | |||||
return mlir::IntegerType::get(8, ctx); | |||||
case megdnn::DTypeEnum::Int8: | |||||
return mlir::IntegerType::get(8, ctx); | |||||
case megdnn::DTypeEnum::Int16: | |||||
return mlir::IntegerType::get(16, ctx); | |||||
case megdnn::DTypeEnum::Int32: | |||||
return mlir::IntegerType::get(32, ctx); | |||||
case megdnn::DTypeEnum::IntB1: | |||||
return mlir::IntegerType::get(1, ctx); | |||||
case megdnn::DTypeEnum::IntB2: | |||||
return mlir::IntegerType::get(2, ctx); | |||||
case megdnn::DTypeEnum::IntB4: | |||||
return mlir::IntegerType::get(4, ctx); | |||||
case megdnn::DTypeEnum::Byte: | |||||
return mlir::IntegerType::get(8, ctx); | |||||
case megdnn::DTypeEnum::Float16: | |||||
return mlir::FloatType::getF16(ctx); | |||||
case megdnn::DTypeEnum::UintB4: | |||||
return mlir::IntegerType::get(4, ctx); | |||||
case megdnn::DTypeEnum::BFloat16: | |||||
return mlir::FloatType::getBF16(ctx); | |||||
case megdnn::DTypeEnum::Bool: | |||||
return mlir::IntegerType::get(1, ctx); | |||||
default: | |||||
mgb_throw(InternalError, "Unsupported MegDNN dtype: %s", | |||||
type.name()); | |||||
} | |||||
} | |||||
megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { | |||||
mlir::Type element_type = type; | |||||
if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { | |||||
element_type = cast.getElementType(); | |||||
} | |||||
megdnn::DTypeEnum enumv; | |||||
if (element_type.isF32()) { | |||||
enumv = megdnn::DTypeEnum::Float32; | |||||
} else if (element_type.isSignlessInteger(1)) { | |||||
enumv = megdnn::DTypeEnum::IntB1; | |||||
} else if (element_type.isSignlessInteger(2)) { | |||||
enumv = megdnn::DTypeEnum::IntB2; | |||||
} else if (element_type.isSignlessInteger(4)) { | |||||
enumv = megdnn::DTypeEnum::IntB4; | |||||
} else if (element_type.isSignlessInteger(8)) { | |||||
enumv = megdnn::DTypeEnum::Int8; | |||||
} else if (element_type.isSignlessInteger(16)) { | |||||
enumv = megdnn::DTypeEnum::Int16; | |||||
} else if (element_type.isSignlessInteger(32)) { | |||||
enumv = megdnn::DTypeEnum::Int32; | |||||
} else if (element_type.isF16()) { | |||||
enumv = megdnn::DTypeEnum::Float16; | |||||
} else if (element_type.isBF16()) { | |||||
enumv = megdnn::DTypeEnum::BFloat16; | |||||
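    // note: i1 already matched the IntB1 branch above, so the Bool branch
    // below is unreachable; megdnn_dtype_to_mlir_type maps Bool and IntB1
    // (and likewise Uint8/Int8/Byte) onto the same MLIR type, so this
    // reverse mapping is inherently lossy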
} else if (element_type.isSignlessInteger(1)) { | |||||
enumv = megdnn::DTypeEnum::Bool; | |||||
} else { | |||||
mgb_throw(InternalError, "Unsupported MLIR Type: %s", | |||||
mlir_type_to_string(element_type).c_str()); | |||||
} | |||||
return megdnn::DType::from_enum(enumv); | |||||
} | |||||
bool is_signed_int_dtype(megdnn::DType type) { | |||||
auto enumv = type.enumv(); | |||||
return enumv == megdnn::DTypeEnum::Int8 or | |||||
enumv == megdnn::DTypeEnum::Int16 or | |||||
enumv == megdnn::DTypeEnum::Int32 or | |||||
enumv == megdnn::DTypeEnum::IntB1 or | |||||
enumv == megdnn::DTypeEnum::IntB2 or | |||||
enumv == megdnn::DTypeEnum::IntB4; | |||||
} | |||||
bool is_unsigned_int_dtype(megdnn::DType type) { | |||||
auto enumv = type.enumv(); | |||||
return enumv == megdnn::DTypeEnum::Uint8 or | |||||
enumv == megdnn::DTypeEnum::UintB4; | |||||
} | |||||
} // namespace jit | |||||
} // namespace mgb | |||||
#endif // MGB_JIT && MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -14,22 +14,33 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
#include "megdnn/dtype.h" | |||||
#include <mlir/IR/StandardTypes.h> | #include <mlir/IR/StandardTypes.h> | ||||
namespace mgb { | namespace mgb { | ||||
namespace jit { | namespace jit { | ||||
inline bool is_elemwise_float(const mlir::Type& dt) { | |||||
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | |||||
if (cast.getElementType().isF32()) { | |||||
return true; | |||||
} | |||||
} | |||||
if (dt.isa<mlir::FloatType>()) { | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
#define FOR_EACH_DNN_DTYPE(cb) \ | |||||
cb(Float32, dt_float32); \ | |||||
cb(Uint8, dt_uint8); \ | |||||
cb(Int8, dt_int8); \ | |||||
cb(Int16, dt_int16); \ | |||||
cb(Int32, dt_int32); \ | |||||
cb(Byte, dt_byte); \ | |||||
MEGDNN_INC_FLOAT16(cb(Float16, dt_float16)); \ | |||||
MEGDNN_INC_FLOAT16(cb(BFloat16, dt_bfloat16)); \ | |||||
cb(Bool, dt_bool); | |||||
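FOR_EACH_DNN_DTYPE is an X-macro: callers define cb to receive the (enum name, ctype) pair and let the macro stamp out one invocation per supported dtype. A hypothetical consumer (dtype_size is illustrative only, not part of this patch):

// hypothetical consumer: map a runtime dtype to its ctype's size
#define cb(_dtype, _ctype)          \
    case megdnn::DTypeEnum::_dtype: \
        return sizeof(megdnn::_ctype);
inline size_t dtype_size(megdnn::DType dt) {
    switch (dt.enumv()) {
        FOR_EACH_DNN_DTYPE(cb)
        default:
            mgb_throw(InternalError, "unhandled dtype: %s", dt.name());
    }
}
#undef cb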
mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, | |||||
mlir::MLIRContext* ctx); | |||||
megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); | |||||
bool is_signed_int_dtype(megdnn::DType type); | |||||
bool is_unsigned_int_dtype(megdnn::DType type); | |||||
} // namespace jit | } // namespace jit | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -13,11 +13,14 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
#include "megbrain/jit/mlir/ir/utils.h" | |||||
#include "./types.h" | |||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
#include "megbrain/jit/mlir/ir/utils.h" | |||||
#include "megdnn/oprs/general.h" | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "megdnn/oprs/general.h" | |||||
#include <mlir/Dialect/Affine/IR/AffineOps.h> | #include <mlir/Dialect/Affine/IR/AffineOps.h> | ||||
#include <mlir/IR/Builders.h> | #include <mlir/IR/Builders.h> | ||||
@@ -44,7 +47,7 @@ mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type,
    return alloc;
}

mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
mlir::Type jit::deduce_elemwise_res_type(mlir::ValueRange operands) {
    megdnn::TensorShapeArray srcs;
    megdnn::TensorShape dst;
    megdnn::DType dst_type;
@@ -59,8 +62,8 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
    }
    megdnn::Elemwise::deduce_shape(srcs, dst);
    mlir::Builder builder(operands[0].getContext());
    return layout_to_mlir_type({dst, mlir_type_to_dtype(operands[0].getType())},
                               builder);
    return layout_to_mlir_type(
            {dst, mlir_type_to_megdnn_dtype(operands[0].getType())}, builder);
}
megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) {
@@ -72,41 +75,21 @@ megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) {
    for (size_t i = 0; i < ret.ndim; i++) {
        ret.shape[i] = real_type.getDimSize(i);
    }
    ret.dtype = mlir_type_to_dtype(real_type.getElementType());
    ret.dtype = mlir_type_to_megdnn_dtype(real_type.getElementType());
    }
    return ret;
}
megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) {
    mlir::Type element_type = type;
    if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) {
        element_type = cast.getElementType();
    }
    if (element_type.isF32()) {
        return megdnn::dtype::Float32{};
    } else {
        mgb_throw(InternalError,
                  "Unsupport mlir type for MemRefType, got: %s\n",
                  mlir_type_to_string(type).c_str());
    }
    return {};
}
mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
                                          mlir::Builder& builder) {
    std::vector<int64_t> shape;
    for (size_t i = 0; i < layout.ndim; i++) {
        shape.push_back(layout[i]);
    }
    switch (layout.dtype.enumv()) {
        case megdnn::DTypeEnum::Float32:
            return mlir::MemRefType::get(shape, builder.getF32Type());
        default:
            mgb_throw(InternalError, "No supported dtype: %s",
                      layout.dtype.name());
    }
    mlir::Type type =
            megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext());
    return mlir::MemRefType::get(shape, type);
}

#endif  // MGB_JIT_MLIR
#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
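With the Float32-only switch replaced by megdnn_dtype_to_mlir_type, layout_to_mlir_type now covers every dtype in FOR_EACH_DNN_DTYPE and becomes the inverse of mlir_type_to_layout. A round-trip sketch under that assumption (illustrative test code, not part of the patch):

#include "megbrain/common.h"
#include <mlir/IR/Builders.h>
#include <mlir/IR/MLIRContext.h>

// Sketch: TensorLayout -> MemRefType -> TensorLayout should preserve
// both the shape and the dtype for any supported dtype.
static void check_layout_roundtrip() {
    mlir::MLIRContext ctx;
    mlir::Builder builder(&ctx);
    megdnn::TensorLayout layout({4, 3, 64, 64}, megdnn::dtype::Int8());
    mlir::MemRefType type = mgb::jit::layout_to_mlir_type(layout, builder);
    megdnn::TensorLayout back = mgb::jit::mlir_type_to_layout(type);
    mgb_assert(back.ndim == 4);
    mgb_assert(back.dtype.enumv() == megdnn::DTypeEnum::Int8);
}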
@@ -15,6 +15,7 @@
#include "./mlir_gen.h"

#include "./ir/each_mode.h"
#include "./ir/types.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/utils.h"
@@ -116,9 +117,9 @@ private:
            return nullptr;
        }

        jit::ReturnOp return_op;
        dialect::ReturnOp return_op;
        if (!return_op) {
            m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc());
            m_builder.create<dialect::ReturnOp>(m_builder.getUnknownLoc());
        }
        std::string op_content = mlir_type_to_string(func_op);
        func_op.setName(
@@ -135,9 +136,7 @@ private:
        cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
            if (opr->same_type<JITPlaceholder>()) {
                return;
            }
            if (opr->same_type<opr::ImmutableTensor>()) {
            } else if (opr->same_type<opr::ImmutableTensor>()) {
                auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
                if (imm.valid()) {
                    auto dtype = imm->dtype();
@@ -150,59 +149,53 @@ private:
                              "dtype, but got %s",
                              dtype.name());
                    }
                    auto&& out = m_builder.create<jit::ConstantScalarOp>(
                    auto&& out = m_builder.create<dialect::ConstantScalarOp>(
                            m_builder.getUnknownLoc(), m_builder.getF32Type(),
                            m_builder.getF32FloatAttr(scalar_value));
                    mgb_assert(mlir::succeeded(
                            declare(opr->output(0)->name(), out)));
                }
            }
            if (opr->same_type<opr::Elemwise>()) {
                auto&& out = gen_op(opr->cast_final<opr::Elemwise>());
            } else if (opr->same_type<opr::Elemwise>()) {
                auto&& out = gen_elemwise(opr->cast_final<opr::Elemwise>());
                mgb_assert(
                        mlir::succeeded(declare(opr->output(0)->name(), out)));
                return;
            } else if (opr->same_type<opr::TypeCvt>()) {
                auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
                mgb_assert(
                        mlir::succeeded(declare(opr->output(0)->name(), out)));
            }
        }}
                .add(internal_graph.output());
        m_builder.create<AssignOp>(m_builder.getUnknownLoc(),
                                   get(internal_graph.output()),
                                   get(args.outputs[0].from));
        m_builder.create<dialect::AssignOp>(m_builder.getUnknownLoc(),
                                            get(internal_graph.output()),
                                            get(args.outputs[0].from));
        return mlir::success();
    }
    mlir::Value gen_op(const opr::Elemwise& opr) {
        switch (opr.param().mode) {
#define cb(mlir_op, mgb_mode)                                            \
    case opr::Elemwise::Mode::mgb_mode:                                  \
        return m_builder.create<jit::mlir_op>(m_builder.getUnknownLoc(), \
                                              get(opr.input(0)),         \
                                              get(opr.input(1)));        \
        break;
            MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
#undef cb
#define cb(mlir_op, mgb_mode)                                            \
    case opr::Elemwise::Mode::mgb_mode:                                  \
        return m_builder.create<jit::mlir_op>(m_builder.getUnknownLoc(), \
                                              get(opr.input(0)));        \
        break;
            MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
#undef cb
#define cb(mlir_op, mgb_mode)                                            \
    case opr::Elemwise::Mode::mgb_mode:                                  \
        return m_builder.create<jit::mlir_op>(                           \
                m_builder.getUnknownLoc(), get(opr.input(0)),            \
                get(opr.input(1)), get(opr.input(2)));                   \
        break;
            MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
#undef cb
            default:
                return nullptr;
    mlir::Value gen_elemwise(const opr::Elemwise& opr) {
        llvm::SmallVector<mlir::Value, 4> operands;
        for (size_t i = 0; i < opr.input().size(); i++) {
            operands.push_back(get(opr.input(i)));
        }
        return nullptr;
        mlir::Type res_type = deduce_elemwise_res_type(operands);
        return m_builder.create<dialect::Elemwise>(
                m_builder.getUnknownLoc(), res_type, mlir::ValueRange(operands),
                opr.param().mode);
    }

    mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
        auto shape = get(opr.input(0))
                             .getType()
                             .dyn_cast_or_null<mlir::MemRefType>()
                             .getShape();
        auto res_type = mlir::MemRefType::get(
                shape,
                megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
        return m_builder.create<dialect::TypeCvt>(
                m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
                opr.input(0)->dtype(), opr.param());
    }

    mlir::Type get_type(const TensorLayout& layout) {
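gen_typecvt derives its result type by keeping the operand's memref shape and swapping in the element type of the destination dtype (for a TypeCvt, opr.param() is the target DType). The same logic in isolation (a restatement of the code above for clarity, not additional patch code):

// Keep the input's shape, replace the element type: this is the
// result-type computation used by gen_typecvt above.
static mlir::MemRefType typecvt_res_type(mlir::Value input,
                                         megdnn::DType dst_dtype,
                                         mlir::MLIRContext* ctx) {
    auto src = input.getType().dyn_cast_or_null<mlir::MemRefType>();
    mgb_assert(src, "typecvt expects a memref operand");
    return mlir::MemRefType::get(
            src.getShape(),
            mgb::jit::megdnn_dtype_to_mlir_type(dst_dtype, ctx));
}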
@@ -0,0 +1,6 @@
# mgb_dialect
set(LLVM_TARGET_DEFINITIONS mgb_dialect.td)
tablegen(MLIR mgb_dialect.h.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-decls")
tablegen(MLIR mgb_dialect.cpp.inc ${MGE_IR_INCLUDE_DIRS} "--gen-op-defs")
add_custom_target(mgb_dialect DEPENDS mgb_dialect.h.inc mgb_dialect.cpp.inc)
add_dependencies(mgb_dialect param_defs_tblgen)
@@ -1,5 +1,5 @@
/**
 * \file src/jit/impl/mlir/ir/dialect.h
 * \file src/jit/include/megbrain/jit/mlir/ir/dialect.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -15,8 +15,7 @@
#include "megbrain_build_config.h"

#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/jit/mlir/ir/interfaces.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include "megdnn/opr_param_defs.h"

#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
@@ -39,7 +38,7 @@ public:
#define GET_OP_CLASSES
using namespace mlir;
#include "megbrain/jit/mlir/ir/ops.h.inc"
#include "megbrain/jit/mlir/ir/mgb_dialect.h.inc"

#endif // MGB_JIT && MGB_JIT_MLIR
@@ -1,28 +0,0 @@ | |||||
/** | |||||
* \file src/jit/include/mlir/ir/interfaces.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain_build_config.h" | |||||
#if MGB_JIT_MLIR | |||||
#include <mlir/IR/OpDefinition.h> | |||||
#include <mlir/IR/Types.h> | |||||
namespace mlir { | |||||
/// Include the auto-generated declarations. | |||||
#include "megbrain/jit/mlir/ir/interfaces.h.inc" | |||||
} | |||||
#endif // MGB_JIT_MLIR | |||||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@
/**
 * \file src/jit/impl/mlir/ir/passes.h
 * \file src/jit/include/megbrain/jit/mlir/ir/passes.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -11,8 +11,8 @@
 */

#pragma once

#include "megbrain_build_config.h"
#include "megbrain_build_config.h"

#if MGB_JIT && MGB_JIT_MLIR

#include <memory>
@@ -1,5 +1,5 @@
/**
 * \file src/jit/include/megbrain/mlir/ir/utils.h
 * \file src/jit/include/megbrain/jit/mlir/ir/utils.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -35,15 +35,19 @@ std::string mlir_type_to_string(T&& t) {
mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc,
                                     mlir::PatternRewriter& rewriter);

mlir::Type deduce_result_type(mlir::ValueRange operands);
mlir::Type deduce_elemwise_res_type(mlir::ValueRange operands);

/**
 * \brief convert mlir type to TensorShape
 * \brief convert MLIR Type to TensorLayout
 */
megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);

megdnn::DType mlir_type_to_dtype(mlir::Type type);

/**
 * \brief convert TensorLayout to MLIR Type
 */
mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
                                     mlir::Builder& builder);

} // namespace jit
} // namespace mgb
@@ -267,6 +267,8 @@ void run_mlir_mode(CompNode cn) {
} // anonymous namespace

/* ===================== TestJITHalideCodeGenCuda ===================== */

#if MGB_JIT_HALIDE
template <typename tag>
class TestJITHalideCodeGenCuda : public ::testing::Test {};
@@ -277,6 +279,8 @@ TYPED_TEST(TestJITHalideCodeGenCuda, run) {
}
#endif

/* ===================== TestJITNvrtcCodeGen ===================== */

template <typename tag>
class TestJITNvrtcCodeGen : public ::testing::Test {};
TYPED_TEST_CASE(TestJITNvrtcCodeGen, test_types);
@@ -285,6 +289,8 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
    run<TypeParam>(Backend::NVRTC, CompNode::load("gpu0"));
}

/* ===================== TestJITMlirCodeGen ===================== */

#if MGB_JIT_MLIR
TEST(TestJITMlirCodeGen, Basic) {
    auto cn = CompNode::load("cpu0");
@@ -299,7 +305,8 @@ TEST(TestJITMlirCodeGen, BasicGPU) {
    run_mlir_broadcast(cn);
}

///////////////////////// unary ///////////////////////////////
/* ===================== TestJITMlirUnaryElemwise ===================== */

// clang-format off
#define FOREACH_UNARY_MODE(cb) \
    cb(RELU) \
@@ -365,7 +372,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
    run_mlir_mode<TypeParam, 1>(cn);
}

///////////////////////// binary ///////////////////////////////
/* ===================== TestJITMlirBinaryElemwise ===================== */

// clang-format off
#define FOREACH_BINARY_MODE(cb) \
    cb(ADD) \
@@ -422,7 +430,8 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
    run_mlir_mode<TypeParam, 2>(cn);
}

///////////////////////// ternary ///////////////////////////////
/* ===================== TestJITMlirTernaryElemwise ===================== */

// clang-format off
#define FOREACH_TERNARY_MODE(cb) \
    cb(COND_LEQ_MOV) \
@@ -456,6 +465,81 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) {

#undef SKIP_MODE
/* ===================== TestJITMlirTypeCvt ===================== */

template <typename itype, typename otype>
void run_typecvt(CompNode cn) {
    set_backend(Backend::MLIR);
    auto graph = ComputingGraph::make();
    HostTensorGenerator<itype, RandomDistribution::UNIFORM> gen(-10, 10);
    auto host_x = gen({23, 42}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::TypeCvt::make(x, otype());

    auto ig_gen =
            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
    for (auto i : get_rev_topo_order(y)) {
        if (!i->template same_type<opr::Host2DeviceCopy>()) {
            ig_gen->add_opr(i);
        }
    }

    auto igraph = ig_gen->generate();
    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());

    HostTensorND host_y, host_y_jit;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_jit, host_y_jit)});
    func->execute();

    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
};

#define add_typecvt_gtest(itype, otype)                                  \
    TEST(TestJITMlirTypeCvt, itype##_to_##otype) {                       \
        run_typecvt<dtype::itype, dtype::otype>(CompNode::load("cpu0")); \
    }                                                                    \
    TEST(TestJITMlirTypeCvt, itype##_to_##otype##_GPU) {                 \
        REQUIRE_GPU(1);                                                  \
        run_typecvt<dtype::itype, dtype::otype>(CompNode::load("gpu0")); \
    }

#if !MEGDNN_DISABLE_FLOAT16
// TODO: the support for f16 and bf16 is currently not complete in mlir

// FPExtOp
// add_typecvt_gtest(Float16, Float32);
// add_typecvt_gtest(BFloat16, Float32);
// add_typecvt_gtest(Float16, BFloat16);

// FPTruncOp
// add_typecvt_gtest(Float32, Float16);
// add_typecvt_gtest(Float32, BFloat16);
// add_typecvt_gtest(Float16, BFloat16);

#endif

// FPToSIOp
add_typecvt_gtest(Float32, Int8);
add_typecvt_gtest(Float32, Int16);
add_typecvt_gtest(Float32, Int32);

// FPToUIOp
add_typecvt_gtest(Float32, Uint8);

// SIToFPOp
add_typecvt_gtest(Int8, Float32);
add_typecvt_gtest(Int16, Float32);
add_typecvt_gtest(Int32, Float32);

// UIToFPOp
add_typecvt_gtest(Uint8, Float32);

#undef add_typecvt_gtest

#endif  // MGB_JIT_MLIR
#endif  // MGB_JIT
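Each add_typecvt_gtest(itype, otype) line stamps out one CPU and one GPU test; for example, add_typecvt_gtest(Float32, Int8) expands to:

TEST(TestJITMlirTypeCvt, Float32_to_Int8) {
    run_typecvt<dtype::Float32, dtype::Int8>(CompNode::load("cpu0"));
}
TEST(TestJITMlirTypeCvt, Float32_to_Int8_GPU) {
    REQUIRE_GPU(1);
    run_typecvt<dtype::Float32, dtype::Int8>(CompNode::load("gpu0"));
}

The comment groups (FPToSIOp, FPToUIOp, SIToFPOp, UIToFPOp) name the MLIR standard-dialect cast op each conversion is expected to lower to.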
@@ -2,7 +2,7 @@
// RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s
func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () {
    %0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} :
    %0 = "mgb.Elemwise"(%lhs, %rhs) {name = "add.f", mode = 16 : i32} :
        (memref<2xf32>, memref<2xf32>) -> memref<2xf32>
    "mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
    mgb.return
@@ -24,7 +24,7 @@ func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) ->
// CHECK: }
func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () {
    %0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} :
    %0 = "mgb.Elemwise"(%lhs, %rhs) {name = "add.f", mode = 16 : i32} :
        (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32>
    "mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> ()
    mgb.return
@@ -55,4 +55,4 @@ func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: m
// CHECK: }
// CHECK: dealloc %0 : memref<4x3x64x64xf32>
// CHECK: return
// CHECK: }
// CHECK: }
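The lit-test update is where the gen_op to gen_elemwise change becomes visible: the per-mode op mgb.add is gone, and a single mgb.Elemwise op carries the mode as an i32 attribute; the name = "add.f" lines imply 16 is the encoding of the ADD mode here. A compile-time guard one could place next to the test (assumption: the generated megdnn::param::Elemwise::Mode enum really numbers ADD as 16; verify against the generated opr_param_defs before relying on it):

#include "megdnn/opr_param_defs.h"

// Assumed mapping between the lit test's `mode = 16 : i32` and the
// generated Elemwise mode enum; re-check if opr_param_defs changes.
static_assert(static_cast<int>(megdnn::param::Elemwise::Mode::ADD) == 16,
              "lit tests encode Elemwise ADD as mode = 16");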