
fix(mgb/jit): mlir doesn't support broadcast

GitOrigin-RevId: 08bfc4c34a
tags/v1.0.0-rc1
Megvii Engine Team, 4 years ago
parent commit: 23437864f9
5 changed files with 69 additions and 8 deletions
1. +25 -3  src/jit/impl/fusion_pass.cpp
2. +0  -3  src/jit/impl/mlir/compiler.cpp
3. +1  -1  src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp
4. +1  -1  src/jit/impl/mlir/ir/types.h
5. +42 -0  src/jit/test/fusion.cpp

+25 -3  src/jit/impl/fusion_pass.cpp

@@ -291,8 +291,27 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
          cond_cn = opr->output(0)->comp_node() ==
                    ig_gen->output()->comp_node(),
          cond_shp = check_shape(opr, ig_gen),
-         cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input;
-    if (cond_readers && cond_cn && cond_shp && cond_nr_inp) {
+         cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
+         cond_mlir_specific = true;
+
+#if MGB_JIT_MLIR
+    //! FIXME mlir doesn't support broadcast currently.
+    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
+    if (!strcmp(backend, "MLIR")) {
+        for (VarNode* var : opr->input()) {
+            if (!SymbolVar{var}.as_immutable_scalar().valid()) {
+                if (opr->node_prop().dep_map().at(var) &
+                    DepType::DEV_VALUE) {
+                    if (!var->shape().eq_shape(opr->output(0)->shape())) {
+                        cond_mlir_specific = false;
+                    }
+                }
+            }
+        }
+    }
+#endif
+    if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
+        cond_mlir_specific) {
         ig_gen->add_opr(opr);
     } else {
         if (opr->same_type<opr::Dimshuffle>()) {
@@ -344,7 +363,10 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
     }

     //! As MLIR backend has some constraints
-    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
+    const char* backend = MGB_GETENV("MGB_JIT_BACKEND");
+    if (!backend) {
+        backend = "DEFAULT";
+    }
     // float elemwise
     if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
         bool ret = true;
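
Aside (an editor's sketch, not part of the commit): the new cond_mlir_specific
check above amounts to "reject MLIR fusion if any non-scalar input consumed as
a device value does not already have the output shape", i.e. if any implicit
broadcast would be required. Below is a minimal self-contained C++
approximation; the Shape struct and mlir_can_fuse name are hypothetical, and
scalars are approximated by all-ones shapes, whereas the real code uses
SymbolVar::as_immutable_scalar() and the DepType::DEV_VALUE dependency test:

#include <cassert>
#include <cstddef>
#include <vector>

struct Shape {
    std::vector<size_t> dims;
    bool is_scalar() const {
        for (size_t d : dims)
            if (d != 1) return false;
        return true;
    }
    bool operator==(const Shape& rhs) const { return dims == rhs.dims; }
};

// Hypothetical analogue of the loop added in process_opr(): true iff the
// operator could be fused without the MLIR backend having to broadcast.
bool mlir_can_fuse(const std::vector<Shape>& inputs, const Shape& output) {
    for (const Shape& in : inputs) {
        if (in.is_scalar())
            continue;      // scalar inputs are always acceptable
        if (!(in == output))
            return false;  // shape mismatch -> broadcast would be needed
    }
    return true;
}

int main() {
    Shape out{{23, 42}};
    // scalar second input: fusable
    assert(mlir_can_fuse({Shape{{23, 42}}, Shape{{1, 1}}}, out));
    // {23, 1} would need broadcasting along axis 1: rejected
    assert(!mlir_can_fuse({Shape{{23, 42}}, Shape{{23, 1}}}, out));
    return 0;
}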


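One more editorial note on the second hunk: MGB_GETENV presumably wraps
getenv(), which returns nullptr when the variable is unset, so the added
fallback keeps later strcmp() calls from receiving a null pointer (undefined
behavior). The pattern in isolation, as a runnable sketch:

#include <cstdio>
#include <cstdlib>
#include <cstring>

int main() {
    // getenv() returns nullptr if MGB_JIT_BACKEND is not set.
    const char* backend = std::getenv("MGB_JIT_BACKEND");
    if (!backend) {
        backend = "DEFAULT";  // same fallback as in can_be_fused() above
    }
    std::printf("backend = %s, is_mlir = %d\n", backend,
                std::strcmp(backend, "MLIR") == 0);
    return 0;
}
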
+0 -3  src/jit/impl/mlir/compiler.cpp

@@ -222,9 +222,6 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module,

 std::unique_ptr<Executable> MLIRCompiler::do_compile(
         const InternalGraph& graph, const JITExecutor::Args& args) {
-    MGB_MARK_USED_VAR(graph);
-    MGB_MARK_USED_VAR(args);
-
     mlir::MLIRContext ctx;
     ctx.printStackTraceOnDiagnostic(true);
     ctx.printOpOnDiagnostic(true);


+1 -1  src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp

@@ -19,7 +19,7 @@
  * implied.
  *
  * This file has been modified by Megvii ("Megvii Modifications").
- * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights
+ * All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights
  * reserved.
  *
  */


+1 -1  src/jit/impl/mlir/ir/types.h

@@ -19,7 +19,7 @@
 namespace mgb {
 namespace jit {

-inline bool is_elemwise_float(const mlir::Type& dt) {
+inline const bool is_elemwise_float(const mlir::Type& dt) {
     if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
         if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
             return true;


+42 -0  src/jit/test/fusion.cpp

@@ -1553,6 +1553,48 @@ TEST(TestJITExecutor, GradBehavior) {
     }
 }

+#if MGB_JIT_MLIR
+
+void run_mlir(CompNode cn) {
+    set_backend(Backend::MLIR);
+
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
+         host_x2 = gen({1, 42}, cn), host_x3 = gen({23, 42}, cn),
+         host_x4 = gen({1, 42}, cn), host_x5 = gen({23, 1}, cn);
+
+    auto make_dst = [&](ComputingGraph& graph) {
+        auto a = opr::Host2DeviceCopy::make(graph, host_x0),
+             b = opr::Host2DeviceCopy::make(graph, host_x1),
+             c = opr::Host2DeviceCopy::make(graph, host_x2),
+             d = opr::Host2DeviceCopy::make(graph, host_x3),
+             e = opr::Host2DeviceCopy::make(graph, host_x4);
+        return a + opr::max(b, c) + opr::max(d, e);
+    };
+    HostTensorND host_y1, host_y2;
+    auto funcs = make_func_pair(host_y1, host_y2, make_dst, 2);
+
+    funcs.first->execute();
+    funcs.second->execute();
+    MGB_ASSERT_TENSOR_EQ(host_y1, host_y2);
+
+    JITExecutor* jit;
+    unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
+    ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
+    ASSERT_EQ(3u, jit->input().size());
+}
+
+TEST(TestJITExecutor, TestJITMlirFusion) {
+    run_mlir(CompNode::load("cpu0"));
+}
+
+TEST(TestJITExecutor, TestJITMlirFusionGpu) {
+    REQUIRE_GPU(1);
+    run_mlir(CompNode::load("gpu0"));
+}
+
+#endif  // MGB_JIT_MLIR
+
 #endif  // MGB_JIT

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
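
Reading the assertions in the new test (an editor's walkthrough, not from the
commit): b/c and d/e have mismatching shapes, so the two opr::max calls would
require broadcasts and are kept out of the JIT subgraph, while the remaining
additions all operate on {23, 42} and get fused. A tiny runnable check of that
shape arithmetic:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    using Shape = std::vector<size_t>;
    Shape a{23, 42}, b{23, 1}, c{1, 42}, d{23, 42}, e{1, 42};
    // max(b, c): {23,1} vs {1,42} -> broadcast needed, stays a plain Elemwise
    assert(b != c);
    // max(d, e): {23,42} vs {1,42} -> broadcast needed, stays a plain Elemwise
    assert(d != e);
    // a + max(b,c) + max(d,e): every operand is {23,42}, so the chain fuses,
    // leaving 2 Elemwise oprs and a JITExecutor with 3 inputs, matching the
    // ASSERT_EQ(2u, ...) and ASSERT_EQ(3u, ...) lines in the test above.
    assert(a == d);
    return 0;
}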
