|
@@ -408,7 +408,7 @@ std::optional<ValueRefList> reduce_grad_rule( |
|
|
[shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) { |
|
|
[shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) { |
|
|
mgb_assert(grads.size() == 1); |
|
|
mgb_assert(grads.size() == 1); |
|
|
ValueRef grad = grads[0]; |
|
|
ValueRef grad = grads[0]; |
|
|
if (!keepdim) { |
|
|
|
|
|
|
|
|
if (!keepdim && grad) { |
|
|
auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis})); |
|
|
auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis})); |
|
|
grad = imperative::apply(*grad_op, grad)[0]; |
|
|
grad = imperative::apply(*grad_op, grad)[0]; |
|
|
} |
|
|
} |
|
|