@@ -372,7 +372,7 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
    for (size_t i = 0; i < SIZE * 2; i ++) {
        MGB_ASSERT_FLOAT_EQ(expect(i), z[i]);
    }
    mgb_assert(host_sub.shape().total_nr_elems() == 4 &&
               host_sub.layout().is_contiguous());
    for (size_t i = 0; i < 4; ++ i) {
        size_t idx = i * (SIZE >> 1);
@@ -390,6 +390,27 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
    }
}

// AddUpdate in gradient path but no gradient flows through it
TEST(TestOprBasicArith, AddUpdateInGradPath) {
    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto dest = opr::SharedDeviceTensor::make(*graph, *gen({42}));
    auto host_x = gen({42});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    // delta depends on x, but is not differentiable w.r.t. x;
    // an invalid grad is registered for AddUpdate to handle this case
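    // (note: SetGrad with a null grad getter makes its output
    // non-differentiable w.r.t. x, while VirtualDep forwards that output
    // and only records an extra dependency on x)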
    auto delta = opr::VirtualDep::make({opr::SetGrad::make(x, nullptr), x});
    auto updated = opr::AddUpdate::make(dest, delta);
    auto y = opr::reduce_ax_sum(updated + x, 0);
    auto dx = cg::grad(y, x);
    HostTensorND host_dx;
    auto func = graph->compile({make_callback_copy(dx, host_dx)});
    func->execute();
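    // y = sum(updated + x) along axis 0 and updated contributes no grad
    // w.r.t. x, so every element of dx is expected to be 1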
    for (size_t i = 0; i < host_dx.shape(0); ++i) {
        MGB_ASSERT_FLOAT_EQ(host_dx.ptr<float>()[i], 1.f);
    }
}

TEST(TestOprBasicArith, MemFwd) {
    constexpr size_t SIZE = 12321;
    HostTensorGenerator<> gen;