assert not keepdims, "can not set axis=None and keepdims=True"
@@ -180,6 +184,9 @@ def _reduce(mode):
if not keepdims:
result = _remove_axis(result, axis)
if self.dtype == np.bool_:
if mode in ["MIN", "MAX"]:
result = result.astype("bool")
return result
return f
@@ -377,7 +384,38 @@ class ArrayMethodMixin(abc.ABC):
def flatten(self):
return self.reshape(-1)
sum = _reduce("SUM")
def sum(self, axis=None, keepdims: bool = False):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).
Same for prod/mean/max/min.
:param axis: the dimension or dimensions to reduce.
:param keepdim: whether the output tensor has ndim retained or not.