@@ -1443,7 +1443,111 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
                },
                CompNode::load("gpu0")};
        checker.set_jit_level(1)
                .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}})
                .run({TensorShape{3, 4, 1, 2}, {4, 1, 2, 3}})
                .run({TensorShape{4, 6, 3, 5}, {6, 3, 5, 4}});
    }
}

TEST(TestJITExecutor, GradBehavior) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    HostTensorGenerator<> gen;
    {
        set_backend(Backend::NVRTC);
        auto graph = ComputingGraph::make();
        auto host_a = gen({2, 3, 4}, cn);
        auto a = opr::Host2DeviceCopy::make(*graph, host_a),
             x = opr::exp(a + 1);
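
        // run JITFusionPass in place so that exp(a + 1) is fused into a
        // single JITExecutor opr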
        gopt::GraphOptimizer gopt;
        gopt.add_pass<gopt::JITFusionPass>();
        VarNodeArray dest_vars{x.node()};
        gopt.apply_inplace(dest_vars);
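        // reduce to a scalar so that the grad of x w.r.t. a can be taken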
        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
        SmallVector<jit::JITExecutor*> jits;
        auto on_opr = [&jits](cg::OperatorNodeBase* op) {
            if (auto jit = op->try_cast_final<jit::JITExecutor>()) {
                jits.push_back(jit);
            }
        };
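        // taking the grad should add a second JITExecutor that computes the
        // backward of the fused exp(a + 1)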
        auto grad_a = cg::grad(x, a);
        cg::DepOprIter{on_opr}.add(grad_a);
        ASSERT_EQ(jits.size(), 2);
        // input of the forward jit executor: host_a
        ASSERT_EQ(jits[0]->input().size(), 1);
        // inputs of the grad jit executor:
        // output of the forward jit executor, output grad
        ASSERT_EQ(jits[1]->input().size(), 2);
        // internal graph is (inputs: og, out | output: og * out)
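        // (og is the gradient w.r.t. the executor's output; the derivative
        // of exp is exp itself, i.e. out, hence og * out)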
        size_t nr_ph = 0, nr_mul = 0;
        cg::DepOprIter{
                [&nr_ph, &nr_mul](cg::OperatorNodeBase* op) {
                    if (op->same_type<jit::JITPlaceholder>()) {
                        ++nr_ph;
                        return;
                    }
                    if (auto mul = op->try_cast_final<opr::Elemwise>()) {
                        using Mode = opr::Elemwise::Mode;
                        if (mul->param().mode == Mode::MUL) {
                            ++nr_mul;
                            return;
                        }
                    }
                    mgb_throw(MegBrainError, "unexpected op %s", op->cname());
                }}
                .add(jits[1]->internal_graph_ptr()->output());
        ASSERT_EQ(nr_ph, 2);
        ASSERT_EQ(nr_mul, 1);
    }
    {
        set_backend(Backend::HALIDE);
        auto graph = ComputingGraph::make();
        auto host_a = gen({2, 3, 4}, cn);
        auto a = opr::Host2DeviceCopy::make(*graph, host_a),
             x = opr::exp(a + 1);
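
        // same exp(a + 1) graph as above, now fused with the Halide backend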
        gopt::GraphOptimizer gopt;
        gopt.add_pass<gopt::JITFusionPass>();
        VarNodeArray dest_vars{x.node()};
        gopt.apply_inplace(dest_vars);
        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
        size_t nr_ops = 0, nr_jits = 0;
        auto on_opr = [&nr_jits, &nr_ops](cg::OperatorNodeBase* op) {
            if (op->same_type<jit::JITExecutor>()) {
                ++nr_jits;
            }
            ++nr_ops;
        };
        auto grad_a = cg::grad(x, a);
        cg::DepOprIter{on_opr}.add(grad_a);
        // with the Halide backend, the grad internal graph is expanded into
        // the original graph, so there is only one JITExecutor
        ASSERT_EQ(nr_jits, 1);
        // the grad of a is broadcast(JITExecutor.output(0), a.shape()),
        // so the oprs that grad_a depends on are H2D(a), JITExecutor,
        // GetVarShape(a) and Broadcast
        ASSERT_EQ(nr_ops, 4);
    }
    {
        set_backend(Backend::NVRTC);
        auto graph = ComputingGraph::make();
        auto host_a = gen({2, 3, 4}, cn);
        auto a = opr::SharedDeviceTensor::make(*graph, *host_a),
             x = a * 2 + 1;
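
        // a is a SharedDeviceTensor here, so every input of the fused opr is
        // a const var; this exercises the const-input grad path checked below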
        gopt::GraphOptimizer gopt;
        gopt.add_pass<gopt::JITFusionPass>();
        VarNodeArray dest_vars{x.node()};
        gopt.apply_inplace(dest_vars);
        x = opr::reduce_sum(dest_vars[0], a.make_scalar_dt(1));
        auto grad_a = cg::grad(x, a);
        // all inputs of the grad jit executor are const, so its internal
        // graph is expanded into the original graph to allow more
        // optimizations; thus no JITExecutor can be found
        cg::DepOprIter{[](cg::OperatorNodeBase* op) {
            ASSERT_FALSE(op->same_type<jit::JITExecutor>());
        }}.add(grad_a);
    }
}