Browse Source

fix(mgb/jit): add bind_shape feature to MLIRCompiler

GitOrigin-RevId: bec6796fbf
release-1.2
Megvii Engine Team 4 years ago
parent
commit
d2910f7ef5
2 changed files with 38 additions and 1 deletions
  1. +1
    -1
      src/jit/impl/mlir/compiler.h
  2. +37
    -0
      src/jit/test/codegen.cpp

+ 1
- 1
src/jit/impl/mlir/compiler.h View File

@@ -33,7 +33,7 @@ public:
MLIRCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::CPU);
Property property() const override {
using F = Property::Flag;
return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
return Property{F::BIND_NDIM | F::BIND_SHAPE,
JITFeatureBits::DIMSHUFFLE, 64};
}



+ 37
- 0
src/jit/test/codegen.cpp View File

@@ -197,6 +197,41 @@ void run_mlir_broadcast(CompNode cn) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}

void run_mlir_different_shape(CompNode cn) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;

auto run = [&](TensorShape tshp) {
auto host_x = gen(tshp, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto y = x * 2;
auto ig_gen =
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());

for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}

auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());

HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();

MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
};

run({23, 42});
run({16, 31});
run({32, 56});
run({10});
}

struct MlirTestOpt {
float low;
float high;
@@ -297,6 +332,7 @@ TEST(TestJITMlirCodeGen, Basic) {
auto cn = CompNode::load("cpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
run_mlir_different_shape(cn);
}

TEST(TestJITMlirCodeGen, BasicGPU) {
@@ -304,6 +340,7 @@ TEST(TestJITMlirCodeGen, BasicGPU) {
auto cn = CompNode::load("gpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
run_mlir_different_shape(cn);
}

/* ===================== TestJITMlirUnaryElemwise ===================== */


Loading…
Cancel
Save