
fix(jit): more testcases on the grad of JITExecutor

GitOrigin-RevId: c3bb405979
release-0.6
Megvii Engine Team, 5 years ago
commit 672d4ad0e0
2 changed files with 109 additions and 5 deletions
  1. src/jit/impl/executor_opr.cpp (+4, -4)
  2. src/jit/test/fusion.cpp (+105, -1)

src/jit/impl/executor_opr.cpp (+4, -4)

@@ -549,8 +549,8 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
         rewriter.auto_replace_outputs(opr);
     });

-    static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr,
-            InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) {
+    auto expand_into_origin_graph = [&rewriter](
+            cg::OperatorNodeBase* opr, const VarNodeArray& grad_inputs) {
         if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
             rewriter.replace_var(
                     opr->output(0), grad_inputs.at(ph->input_id()));
@@ -571,7 +571,7 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
         // oprs
         using namespace std::placeholders;
         rewriter.iter(std::bind(expand_into_origin_graph, _1,
-                                std::ref(rewriter), std::cref(grad_inputs)));
+                                std::cref(grad_inputs)));
         return rewriter.dest_var();
     } else {
         VarNodeArray new_grad_inputs;
@@ -602,7 +602,7 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
         // infer and const folding mechanism
         using namespace std::placeholders;
         rewriter.iter(std::bind(expand_into_origin_graph, _1,
-                                std::ref(rewriter), std::cref(new_grad_inputs)));
+                                std::cref(new_grad_inputs)));
         return rewriter.dest_var();
     }
     gx = rewriter.dest_var();
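
The refactor in the hunks above hinges on dropping the static keyword: a function-local static lambda object is initialized only once, so it cannot safely capture the per-call rewriter, and the old version therefore had to take it as an explicit parameter and thread it through std::bind/std::ref at every call site. Once the lambda captures rewriter by reference, both its signature and the two bind call sites lose one argument. Below is a minimal standalone sketch of the two forms; it is not MegEngine code, and Rewriter, the int "operators" and the values are stand-ins:

// Sketch only: contrasts passing a rewriter via std::bind/std::ref with
// capturing it by reference, as in the hunks above.
#include <cstdio>
#include <functional>
#include <initializer_list>
#include <vector>

struct Rewriter {
    // Stand-in for InternalGraphRewriter::iter: calls cb once per "operator".
    void iter(const std::function<void(int)>& cb) {
        for (int opr : {1, 2, 3})
            cb(opr);
    }
};

int main() {
    Rewriter rewriter;
    std::vector<int> grad_inputs{10, 20, 30};
    using namespace std::placeholders;

    // Old form: a capture-less lambda receives the rewriter as an explicit
    // parameter, so the call site must thread it through std::ref.
    auto expand_v1 = [](int opr, Rewriter&, const std::vector<int>& gi) {
        std::printf("v1: opr=%d grad_input=%d\n", opr, gi[opr - 1]);
    };
    rewriter.iter(
            std::bind(expand_v1, _1, std::ref(rewriter), std::cref(grad_inputs)));

    // New form: capture the local rewriter by reference; the signature and the
    // bind call site each lose one argument.
    auto expand_v2 = [&rewriter](int opr, const std::vector<int>& gi) {
        (void)rewriter;  // the real code calls rewriter.replace_var(...) here
        std::printf("v2: opr=%d grad_input=%d\n", opr, gi[opr - 1]);
    };
    rewriter.iter(std::bind(expand_v2, _1, std::cref(grad_inputs)));
}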


src/jit/test/fusion.cpp (+105, -1)

@@ -1443,7 +1443,111 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
                 },
                 CompNode::load("gpu0")};
         checker.set_jit_level(1)
-                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}});
+                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}})
+                .run({TensorShape{3, 4, 1, 2}, {4, 1, 2, 3}})
+                .run({TensorShape{4, 6, 3, 5}, {6, 3, 5, 4}});
     }
 }

+TEST(TestJITExecutor, GradBehavior) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    HostTensorGenerator<> gen;
+    {
+        set_backend(Backend::NVRTC);
+        auto graph = ComputingGraph::make();
+        auto host_a = gen({2, 3, 4}, cn);
+        auto a = opr::Host2DeviceCopy::make(*graph, host_a),
+             x = opr::exp(a + 1);
+
+        gopt::GraphOptimizer gopt;
+        gopt.add_pass<gopt::JITFusionPass>();
+        VarNodeArray dest_vars{x.node()};
+        gopt.apply_inplace(dest_vars);
+        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
+        SmallVector<jit::JITExecutor*> jits;
+        auto on_opr = [&jits](cg::OperatorNodeBase* op) {
+            if (auto jit = op->try_cast_final<jit::JITExecutor>()) {
+                jits.push_back(jit);
+            }
+        };
+        auto grad_a = cg::grad(x, a);
+        cg::DepOprIter{on_opr}.add(grad_a);
+        ASSERT_EQ(jits.size(), 2);
+        // input of forward jit executor: host_a
+        ASSERT_EQ(jits[0]->input().size(), 1);
+        // inputs of grad jit executor:
+        // output of forward jit executor, output grad
+        ASSERT_EQ(jits[1]->input().size(), 2);
+        // internal graph is (input: og, out | output: og * out)
+        size_t nr_ph = 0, nr_mul = 0;
+        cg::DepOprIter{
+                [&nr_ph, &nr_mul](cg::OperatorNodeBase* op) {
+                    if (op->same_type<jit::JITPlaceholder>()) {
+                        ++nr_ph;
+                        return;
+                    }
+                    if (auto mul = op->try_cast_final<opr::Elemwise>()) {
+                        using Mode = opr::Elemwise::Mode;
+                        if (mul->param().mode == Mode::MUL) {
+                            ++nr_mul;
+                            return;
+                        }
+                    }
+                    mgb_throw(MegBrainError, "unexpected op %s", op->cname());
+                }}
+                .add(jits[1]->internal_graph_ptr()->output());
+        ASSERT_EQ(nr_ph, 2);
+        ASSERT_EQ(nr_mul, 1);
+    }
+    {
+        set_backend(Backend::HALIDE);
+        auto graph = ComputingGraph::make();
+        auto host_a = gen({2, 3, 4}, cn);
+        auto a = opr::Host2DeviceCopy::make(*graph, host_a),
+             x = opr::exp(a + 1);
+
+        gopt::GraphOptimizer gopt;
+        gopt.add_pass<gopt::JITFusionPass>();
+        VarNodeArray dest_vars{x.node()};
+        gopt.apply_inplace(dest_vars);
+        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
+        size_t nr_ops = 0, nr_jits = 0;
+        auto on_opr = [&nr_jits, &nr_ops](cg::OperatorNodeBase* op) {
+            if (op->same_type<jit::JITExecutor>()) {
+                ++nr_jits;
+            }
+            ++nr_ops;
+        };
+        auto grad_a = cg::grad(x, a);
+        cg::DepOprIter{on_opr}.add(grad_a);
+        // in the Halide backend, the grad internal graph is expanded into
+        // the original graph, so there is only one JITExecutor
+        ASSERT_EQ(nr_jits, 1);
+        // the grad of a is broadcast(JITExecutor.output(0), a.shape()),
+        // so the oprs that grad_a depends on are H2D(a), JITExecutor,
+        // GetVarShape(a) and Broadcast
+        ASSERT_EQ(nr_ops, 4);
+    }
+    {
+        set_backend(Backend::NVRTC);
+        auto graph = ComputingGraph::make();
+        auto host_a = gen({2, 3, 4}, cn);
+        auto a = opr::SharedDeviceTensor::make(*graph, *host_a),
+             x = a * 2 + 1;
+
+        gopt::GraphOptimizer gopt;
+        gopt.add_pass<gopt::JITFusionPass>();
+        VarNodeArray dest_vars{x.node()};
+        gopt.apply_inplace(dest_vars);
+        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
+        auto grad_a = cg::grad(x, a);
+        // all inputs of the grad jit executor are const, so its internal
+        // graph is expanded into the original graph for further
+        // optimization; therefore no JITExecutor can be found
+        cg::DepOprIter{[](cg::OperatorNodeBase* op) {
+            ASSERT_FALSE(op->same_type<jit::JITExecutor>());
+        }}.add(grad_a);
+    }
+}
+
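
For context on the counts asserted in the first block of TestJITExecutor.GradBehavior: the fused forward kernel computes out = exp(a + 1), whose derivative with respect to a is again exp(a + 1) = out, and because x is a plain sum reduction the incoming output gradient og is all ones, so grad_a reduces to og * out. That is why the grad executor's internal graph is expected to hold exactly two JITPlaceholders (og and out) and a single MUL elemwise. Below is a minimal standalone numeric sketch of that identity; it is not part of the commit, and the scalar values are illustrative only:

// Sketch only: checks grad_a = og * out for out = exp(a + 1) at one point.
#include <cmath>
#include <cstdio>

int main() {
    double a = 0.3;
    double og = 1.0;  // gradient of reduce_sum w.r.t. each element of out
    double out = std::exp(a + 1.0);
    double grad_analytic = og * out;  // what the fused grad kernel computes
    // finite-difference cross-check of d(exp(a + 1))/da
    double eps = 1e-6;
    double grad_numeric =
            (std::exp(a + eps + 1.0) - std::exp(a - eps + 1.0)) / (2.0 * eps);
    std::printf("analytic=%.6f numeric=%.6f\n", grad_analytic, grad_numeric);
    return 0;
}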


