diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index be6c95f9..c576aab2 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -408,7 +408,7 @@ std::optional reduce_grad_rule( [shapes = std::move(input_shapes), axis, keepdim](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - if (!keepdim) { + if (!keepdim && grad) { auto&& grad_op = AddAxis::make(std::vector({axis})); grad = imperative::apply(*grad_op, grad)[0]; }