Browse Source

fix(mgb): register invalid grad for AddUpdate

GitOrigin-RevId: f9bbf570dc
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
2f3d185de6
2 changed files with 27 additions and 1 deletions
  1. +5
    -0
      src/opr/impl/basic_arith.cpp
  2. +22
    -1
      src/opr/test/basic_arith/others.cpp

+ 5
- 0
src/opr/impl/basic_arith.cpp View File

@@ -947,6 +947,11 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}

MGB_IMPL_OPR_GRAD(AddUpdate) {
// actually valid, just not implemented
return InvalidGrad::make(opr, wrt_idx);
}

/* =========================== Reduce =========================== */

class Reduce::KernScheduler {


+ 22
- 1
src/opr/test/basic_arith/others.cpp View File

@@ -372,7 +372,7 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
for (size_t i = 0; i < SIZE * 2; i ++) {
MGB_ASSERT_FLOAT_EQ(expect(i), z[i]);
}
mgb_assert(host_sub.shape().total_nr_elems() == 4 &&
mgb_assert(host_sub.shape().total_nr_elems() == 4 &&
host_sub.layout().is_contiguous());
for (size_t i = 0; i < 4; ++ i) {
size_t idx = i * (SIZE >> 1);
@@ -390,6 +390,27 @@ TEST(TestOprBasicArith, AddUpdateVolatile) {
}
}

// AddUpdate in gradient path but no gradient flows through it
TEST(TestOprBasicArith, AddUpdateInGradPath) {
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
auto dest = opr::SharedDeviceTensor::make(*graph, *gen({42}));
auto host_x = gen({42});
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
// delta depends on x, but not differentiable wrt x
// a invalid grad is registered for AddUpdate to fix this case
auto delta = opr::VirtualDep::make({opr::SetGrad::make(x, nullptr), x});
auto updated = opr::AddUpdate::make(dest, delta);
auto y = opr::reduce_ax_sum(updated + x, 0);
auto dx = cg::grad(y, x);
HostTensorND host_dx;
auto func = graph->compile({make_callback_copy(dx, host_dx)});
func->execute();
for (size_t i = 0; i < host_dx.shape(0); ++i) {
MGB_ASSERT_FLOAT_EQ(host_dx.ptr<float>()[i], 1.f);
}
}

TEST(TestOprBasicArith, MemFwd) {
constexpr size_t SIZE = 12321;
HostTensorGenerator<> gen;


Loading…
Cancel
Save