diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index cd07f654..30997267 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -136,6 +136,46 @@ def test_grad_with_tensor_wrapper():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+def test_wrt_intermediate_var():
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+
+    result = {}
+
+    with Grad() as grad:
+        grad.wrt(x, callback=lambda dx: result.update(dx=dx))
+        y = mul(x, x)
+        grad.wrt(y, callback=lambda dy: result.update(dy=dy))
+        z = mul(y, y)
+        grad(z, mge.Tensor(np.ones_like(x_np)))
+
+    np.testing.assert_almost_equal(result["dx"].numpy(), 4 * x_np ** 3, decimal=6)
+    np.testing.assert_almost_equal(result["dy"].numpy(), 2 * (x_np ** 2), decimal=6)
+
+
+@pytest.mark.parametrize("in_path", [False, True])
+def test_wrt_visibility(in_path):
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+
+    def copy(x):
+        xx = mge.Tensor(x)
+        xx._reset(x)
+        return xx
+
+    result = {}
+
+    with Grad() as grad:
+        if in_path:
+            grad.wrt(x, callback=lambda _: None)
+        y = mul(x, x)
+        grad.wrt(copy(y), callback=lambda dy: result.update(dy=dy))
+        z = mul(y, y)
+        grad(z, mge.Tensor(np.ones_like(x_np)))
+
+    assert not result
+
+
 def test_release():
     def check(f):
         n = 0
diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp
index 54e93100..7a9879b8 100644
--- a/imperative/src/impl/transformations/grad.cpp
+++ b/imperative/src/impl/transformations/grad.cpp
@@ -265,20 +265,21 @@ void GradKey::backward() {
 
 GradValue::ref_t GradKey::attach(
         ValueRef tensor, std::function<void(ValueRef)> callback) {
-    auto grad_value = tensor.as_ref(m_value_type);
-    if (grad_value) {
-        mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists");
-    } else {
-        GradSlotPtr grad_slot;
-        auto& grad_fn = grad_slot.m_fn;
-        grad_fn = LocalPtr<GradFn>::make();
-        grad_fn->m_key = shared_from_this();
-        grad_fn->m_slots.resize(1);
-        grad_slot.m_index = 0;
-        grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot);
+    // always create a new grad value
+    GradSlotPtr grad_slot;
+    auto& grad_fn = grad_slot.m_fn;
+    grad_fn = LocalPtr<GradFn>::make();
+    grad_fn->m_key = shared_from_this();
+    grad_fn->m_slots.resize(1);
+    grad_fn->m_slots[0].callback = callback;
+    grad_slot.m_index = 0;
+    if (auto&& grad_value = tensor.as_ref(m_value_type)) {
+        grad_fn->m_backward.emplace<IdentityBackward>();
+        grad_fn->m_dests.push_back(grad_value->m_slot);
+        tensor = grad_value->m_value;
+        m_tape.emplace_back(grad_fn, nullptr);
     }
-    grad_value->slot().m_fn->m_slots[0].callback = callback;
-    return grad_value;
+    return m_value_type.make(tensor, shared_from_this(), grad_slot);
 }
 
 void GradKey::freeze() {
@@ -424,22 +425,17 @@ ValueRefList GradTransformation::apply_transformation(
         return outputs;
     } else if (op.is<CreateTensor>()) {
         return imperative::apply(op, inputs);
-    }
-    if (auto* attach_grad = op.as<AttachGrad>()) {
-        auto& tensor = inputs[0];
-        if (auto&& grad_value = tensor.as_ref(m_value_type)) {
-            mgb_assert(!has_key(attach_grad->key()));
-            auto output = fallback()[0];
-            return record_grad(m_value_type.make(output, m_key, grad_value->slot()));
-        } else if (!has_key(attach_grad->key())) {
+    } else if (auto* attach_grad = op.as<AttachGrad>()) {
+        if (!has_key(attach_grad->key())) {
             return fallback();
         } else {
             GenericFunction callback =
                     (GenericFunction&)inputs[1].cast<FunctionValue>();
-            auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
-                auto ret = callback({&grad, 1});
-                assert(ret.empty());
-            });
+            auto output =
+                    attach_grad->key()->attach(inputs[0], [callback](ValueRef grad) {
+                        auto ret = callback({&grad, 1});
+                        mgb_assert(ret.empty());
+                    });
             return {record_grad(output)};
         }
     } else if (auto* grad_backward = op.as<GradBackward>()) {
diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h
index 7c5cec9d..6f08e2b7 100644
--- a/imperative/src/include/megbrain/imperative/transformations/grad.h
+++ b/imperative/src/include/megbrain/imperative/transformations/grad.h
@@ -83,6 +83,20 @@ public:
     static BackwardRule lookup_grad_rule(Typeinfo* typeinfo);
 };
 
+struct IdentityBackward {
+    bool input_has_grad(size_t i) { mgb_assert(0); }
+    bool output_requires_grad(size_t i) { mgb_assert(0); }
+
+    template <typename F>
+    void operator()(Span<ValueRef> grads, F&& receiver) {
+        for (size_t i = 0; i < grads.size(); ++i) {
+            if (grads[i]) {
+                receiver(i, grads[i]);
+            }
+        }
+    }
+};
+
 class GradSlot;
 class GradSlotPtr;
 class GradSlotProducerPtr;
@@ -165,7 +179,9 @@ private:
     std::weak_ptr<GradKey> m_key;
     SmallVector<GradSlot> m_slots;
     SmallVector<GradSlotProducerPtr> m_dests;
-    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
+    std::variant<
+            std::monostate, BackwardGraphWithClosure, CustomBackward, IdentityBackward>
+            m_backward;
 
 public:
     void clear() {