Browse Source

fix(mge): fix bool.sum()

GitOrigin-RevId: 62c482db40
release-1.1
Megvii Engine Team 4 years ago
parent
commit
1965276845
2 changed files with 49 additions and 6 deletions
  1. +39
    -1
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +10
    -5
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 39
- 1
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -158,6 +158,10 @@ def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
(data,) = utils.convert_inputs(data)
if mode == "MEAN":
data = data.astype("float32")
elif self.dtype == np.bool_:
data = data.astype("int32")
if axis is None:
data = data.reshape(-1)
assert not keepdims, "can not set axis=None and keepdims=True"
@@ -180,6 +184,9 @@ def _reduce(mode):

if not keepdims:
result = _remove_axis(result, axis)
if self.dtype == np.bool_:
if mode in ["MIN", "MAX"]:
result = result.astype("bool")
return result

return f
@@ -377,7 +384,38 @@ class ArrayMethodMixin(abc.ABC):
def flatten(self):
return self.reshape(-1)

sum = _reduce("SUM")
def sum(self, axis=None, keepdims: bool = False):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).

Same for prod/mean/max/min.

:param axis: the dimension or dimensions to reduce.
:param keepdim: whether the output tensor has ndim retained or not.
:return: output tensor.

Examples:

.. testcode::

from megengine import tensor
a = tensor([False, True, True, False])
b = tensor([1.0, 2.0, 3.0, 4.0])
print(a.sum().numpy())
print(b.sum().numpy())

Outputs:

.. testoutput::

[2]
[10.]

"""
return _reduce("SUM")(self, axis, keepdims)

prod = _reduce("PRODUCT")
min = _reduce("MIN")
max = _reduce("MAX")


+ 10
- 5
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -35,11 +35,16 @@ def test_matmul():


def test_reduce():
for m in ["sum", "prod", "min", "max", "mean"]:
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)
y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]:
x = TensorWrapper(x_np)
y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)

test_x((10 * np.random.rand(10) + 1).astype("int32"))
test_x(np.random.rand(10).astype("float32"))
test_x(np.array([True, True, True]))
test_x(np.array([True, False, True]))


def test_set_subtensor():


Loading…
Cancel
Save