GitOrigin-RevId: 08bfc4c34a
@@ -291,8 +291,27 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
          cond_cn = opr->output(0)->comp_node() ==
                    ig_gen->output()->comp_node(),
          cond_shp = check_shape(opr, ig_gen),
-         cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input;
-    if (cond_readers && cond_cn && cond_shp && cond_nr_inp) {
+         cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
+         cond_mlir_specific = true;
+#if MGB_JIT_MLIR
+    //! FIXME: MLIR doesn't support broadcast currently.
+    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
+    if (!strcmp(backend, "MLIR")) {
+        for (VarNode* var : opr->input()) {
+            if (!SymbolVar{var}.as_immutable_scalar().valid()) {
+                if (opr->node_prop().dep_map().at(var) &
+                    DepType::DEV_VALUE) {
+                    if (!var->shape().eq_shape(opr->output(0)->shape())) {
+                        cond_mlir_specific = false;
+                    }
+                }
+            }
+        }
+    }
+#endif
+    if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
+        cond_mlir_specific) {
         ig_gen->add_opr(opr);
     } else {
         if (opr->same_type<opr::Dimshuffle>()) {
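
The guard added above keeps an operator out of an MLIR fusion whenever one of its non-scalar DEV_VALUE inputs does not already match the output shape, i.e. whenever the backend would have to broadcast. A minimal standalone sketch of the same predicate, assuming the usual MegBrain graph APIs (the helper name and the explicit DepType qualification are illustrative, not part of the patch):

    // Sketch: true iff fusing `opr` needs no implicit broadcast, i.e. every
    // non-scalar device-value input already has the output shape.
    static bool mlir_fusible_without_broadcast(cg::OperatorNodeBase* opr) {
        using DepType = cg::OperatorNodeBase::NodeProp::DepType;
        for (VarNode* var : opr->input()) {
            // immutable scalars are lowered as constants, so they never need a broadcast
            if (SymbolVar{var}.as_immutable_scalar().valid())
                continue;
            // only device-value dependencies feed the generated kernel
            if (opr->node_prop().dep_map().at(var) & DepType::DEV_VALUE) {
                if (!var->shape().eq_shape(opr->output(0)->shape()))
                    return false;  // this input would have to be broadcast
            }
        }
        return true;
    }
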
@@ -344,7 +363,10 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
     }
     //! As the MLIR backend has some constraints
-    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
+    const char* backend = MGB_GETENV("MGB_JIT_BACKEND");
+    if (!backend) {
+        backend = "DEFAULT";
+    }
     // float elemwise
     if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
         bool ret = true;
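
MGB_GETENV can return a null pointer when MGB_JIT_BACKEND is unset, so giving `backend` a "DEFAULT" fallback before the later strcmp comparisons keeps them well defined. A minimal sketch of the same null-safe lookup, using std::getenv in place of MGB_GETENV and an illustrative helper name:

    #include <cstdlib>
    #include <cstring>

    // Sketch: read the JIT backend selector with a safe fallback so strcmp()
    // never receives a null pointer when the variable is unset.
    static const char* jit_backend_or_default() {
        const char* backend = std::getenv("MGB_JIT_BACKEND");
        return backend ? backend : "DEFAULT";
    }

    // usage: if (!std::strcmp(jit_backend_or_default(), "MLIR")) { ... }
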
@@ -222,9 +222,6 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module,
 std::unique_ptr<Executable> MLIRCompiler::do_compile(
         const InternalGraph& graph, const JITExecutor::Args& args) {
-    MGB_MARK_USED_VAR(graph);
-    MGB_MARK_USED_VAR(args);
     mlir::MLIRContext ctx;
     ctx.printStackTraceOnDiagnostic(true);
     ctx.printOpOnDiagnostic(true);
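
Dropping the two MGB_MARK_USED_VAR calls suggests that `graph` and `args` are now genuinely referenced inside do_compile; the macro exists only to silence unused-parameter warnings. A typical definition of such a macro looks like the following sketch (illustrative, not necessarily MegBrain's exact definition):

    // Swallow a value to suppress unused-variable / unused-parameter warnings.
    #define MARK_USED_VAR(v) static_cast<void>(v)
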
@@ -19,7 +19,7 @@
  * implied.
  *
  * This file has been modified by Megvii ("Megvii Modifications").
- * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights
+ * All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights
  * reserved.
  *
  */
@@ -19,7 +19,7 @@
 namespace mgb {
 namespace jit {
 
-inline bool is_elemwise_float(const mlir::Type& dt) {
+inline const bool is_elemwise_float(const mlir::Type& dt) {
     if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
         if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
             return true;
@@ -1553,6 +1553,48 @@ TEST(TestJITExecutor, GradBehavior) {
     }
 }
 
+#if MGB_JIT_MLIR
+void run_mlir(CompNode cn) {
+    set_backend(Backend::MLIR);
+    HostTensorGenerator<> gen;
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
+         host_x2 = gen({1, 42}, cn), host_x3 = gen({23, 42}, cn),
+         host_x4 = gen({1, 42}, cn), host_x5 = gen({23, 1}, cn);
+    auto make_dst = [&](ComputingGraph& graph) {
+        auto a = opr::Host2DeviceCopy::make(graph, host_x0),
+             b = opr::Host2DeviceCopy::make(graph, host_x1),
+             c = opr::Host2DeviceCopy::make(graph, host_x2),
+             d = opr::Host2DeviceCopy::make(graph, host_x3),
+             e = opr::Host2DeviceCopy::make(graph, host_x4);
+        return a + opr::max(b, c) + opr::max(d, e);
+    };
+    HostTensorND host_y1, host_y2;
+    auto funcs = make_func_pair(host_y1, host_y2, make_dst, 2);
+    funcs.first->execute();
+    funcs.second->execute();
+    MGB_ASSERT_TENSOR_EQ(host_y1, host_y2);
+
+    JITExecutor* jit;
+    unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
+    ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
+    ASSERT_EQ(3u, jit->input().size());
+}
+
+TEST(TestJITExecutor, TestJITMlirFusion) {
+    run_mlir(CompNode::load("cpu0"));
+}
+
+TEST(TestJITExecutor, TestJITMlirFusionGpu) {
+    REQUIRE_GPU(1);
+    run_mlir(CompNode::load("gpu0"));
+}
+#endif  // MGB_JIT_MLIR
+
 #endif  // MGB_JIT
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
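
Reading the new test: `b` ({23, 1}), `c` ({1, 42}) and `e` ({1, 42}) carry broadcasting shapes, so under the MLIR broadcast restriction the two opr::max operators stay outside the fused region, while the additions over full {23, 42} tensors are fused. A worked trace of the shapes, consistent with the two assertions:

    // a : {23, 42}                 -> matches the output shape, fusable
    // b : {23, 1},  c : {1, 42}    -> max(b, c) needs a broadcast -> left as Elemwise
    // d : {23, 42}, e : {1, 42}    -> max(d, e) needs a broadcast -> left as Elemwise
    // The JIT region computes a + t1 + t2 with t1 = max(b, c) and t2 = max(d, e),
    // so jit->input().size() == 3 and two Elemwise operators remain in the graph.
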