Browse Source

fix(imperative): fix the segmentfault when reduce backward

GitOrigin-RevId: 8a3e63d4f5
release-1.10
Megvii Engine Team 3 years ago
parent
commit
7b4b94fd93
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      imperative/python/src/grad_override.cpp

+ 1
- 1
imperative/python/src/grad_override.cpp View File

@@ -408,7 +408,7 @@ std::optional<ValueRefList> reduce_grad_rule(
[shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) { [shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1); mgb_assert(grads.size() == 1);
ValueRef grad = grads[0]; ValueRef grad = grads[0];
if (!keepdim) {
if (!keepdim && grad) {
auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis})); auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis}));
grad = imperative::apply(*grad_op, grad)[0]; grad = imperative::apply(*grad_op, grad)[0];
} }


Loading…
Cancel
Save