
refactor(mge/tensor): tensor reduce supports keepdims

GitOrigin-RevId: 8ed95e0fb8
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent commit: 7e8f720953
5 changed files with 56 additions and 51 deletions
  1. imperative/python/megengine/core/tensor/tensor_wrapper.py (+46, -7)
  2. imperative/python/megengine/functional/loss.py (+3, -3)
  3. imperative/python/megengine/functional/math.py (+5, -39)
  4. imperative/python/megengine/functional/nn.py (+1, -1)
  5. imperative/python/test/unit/test_tensor_wrapper.py (+1, -1)

imperative/python/megengine/core/tensor/tensor_wrapper.py (+46, -7)

@@ -134,15 +134,54 @@ def _logical_binary_elwise(mode, rev=False):
     return f
 
 
+def _remove_axis(inp: Tensor, axis) -> Tensor:
+    Param = builtin.AxisAddRemove.Param
+
+    def get_axes():
+        if axis is None:
+            return [i for i, s in enumerate(inp.shape) if s == 1]
+        try:
+            return [int(axis)]
+        except (TypeError, ValueError):
+            pass
+        return list(map(int, axis))
+
+    axis = get_axes()
+    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
+    axis = [a - i for i, a in enumerate(axis)]
+
+    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
+    op = builtin.AxisAddRemove(param=param)
+    (result,) = apply(op, inp)
+    return result
+
+
 def _reduce(mode):
-    def f(self, axis=None):
-        inp = self
+    def f(self, axis=None, keepdims: bool = False):
+        data = self
+        (data,) = utils.convert_inputs(data)
         if axis is None:
-            inp = self.flatten()
-            axis = 0
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = utils.convert_inputs(inp)
-        (result,) = apply(op, result)
+            data = data.reshape(-1)
+            assert not keepdims, "can not set axis=None and keepdims=True"
+
+            op = builtin.Reduce(mode=mode, axis=0)
+            (result,) = apply(op, data)
+        elif isinstance(axis, collections.Iterable):
+            axis = list(axis)
+            axis.sort(reverse=True)
+
+            for ai in axis:
+                op = builtin.Reduce(mode=mode, axis=ai)
+                (data,) = apply(op, data)
+                if not keepdims:
+                    data = _remove_axis(data, ai)
+            result = data
+        else:
+            op = builtin.Reduce(mode=mode, axis=axis)
+            (result,) = apply(op, data)
+
+            if not keepdims:
+                result = _remove_axis(result, axis)
         return result
 
     return f
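
The new _reduce walks the requested axes from highest index to lowest and squeezes each reduced axis only when keepdims is off; sorting in reverse keeps the remaining axis indices valid while axes are dropped one by one. A minimal NumPy sketch of the same semantics (illustrative only, not MegEngine code; the helper name reduce_sum is made up here):

import numpy as np

def reduce_sum(data, axis=None, keepdims=False):
    # mirrors tensor_wrapper._reduce: flatten when axis is None,
    # otherwise reduce one axis at a time, highest index first,
    # dropping each reduced axis unless keepdims is requested
    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        return np.add.reduce(data.reshape(-1), axis=0)
    axes = sorted(np.atleast_1d(axis).tolist(), reverse=True)
    for ai in axes:
        data = np.add.reduce(data, axis=ai, keepdims=keepdims)
    return data

x = np.arange(24, dtype="float32").reshape(2, 3, 4)
print(reduce_sum(x, axis=(0, 2)).shape)                 # (3,)
print(reduce_sum(x, axis=(0, 2), keepdims=True).shape)  # (1, 3, 1)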


imperative/python/megengine/functional/loss.py (+3, -3)

@@ -176,15 +176,15 @@ def cross_entropy_with_softmax(
     num_classes = pred.shape[axis]
 
     # Denominator of the softmax
-    offset = pred.max(axis=axis).detach()
+    offset = pred.max(axis=axis, keepdims=True).detach()
     pred = pred - offset
-    down = exp(pred).sum(axis=axis)
+    down = exp(pred).sum(axis=axis, keepdims=True)
 
     up = indexing_one_hot(pred, label, axis)
 
     if label_smooth != 0:
         factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor
+        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
 
     return (log(down) - up).mean()
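
keepdims=True is what lets the per-sample max and sum broadcast back against pred along the class axis; without it the reduced tensor loses that axis and the subtraction no longer lines up. A NumPy illustration of the shape issue (hypothetical shapes, not the loss code itself):

import numpy as np

pred = np.random.rand(4, 10).astype("float32")   # (batch, num_classes), reducing over axis=1
offset = pred.max(axis=1, keepdims=True)         # shape (4, 1)
stable = pred - offset                           # (4, 10) - (4, 1) broadcasts fine
# pred.max(axis=1) has shape (4,); "(4, 10) - (4,)" raises a broadcast error,
# which is why the reduction must keep the class axis here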




imperative/python/megengine/functional/math.py (+5, -39)

@@ -117,40 +117,6 @@ def sign(inp: Tensor):
     raise NotImplementedError
 
 
-def _reduce(
-    data,
-    *,
-    mode,
-    axis: Optional[Union[int, Sequence[int]]] = None,
-    keepdims: bool = False
-):
-    (data,) = utils.convert_inputs(data)
-    if axis is None:
-        data = data.reshape(-1)
-        assert not keepdims, "can not set axis=None and keepdims=True"
-
-        op = builtin.Reduce(mode=mode, axis=0)
-        (result,) = apply(op, data)
-    elif isinstance(axis, collections.Iterable):
-        axis = list(axis)
-        axis.sort(reverse=True)
-
-        for ai in axis:
-            op = builtin.Reduce(mode=mode, axis=ai)
-            (data,) = apply(op, data)
-            if not keepdims:
-                data = remove_axis(data, ai)
-        result = data
-    else:
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = apply(op, data)
-
-        if not keepdims:
-            result = remove_axis(result, axis)
-
-    return result
-
-
 def sum(
     inp: Tensor,
     axis: Optional[Union[int, Sequence[int]]] = None,
@@ -182,7 +148,7 @@ def sum(
     [21]
 
     """
-    return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims)
+    return inp.sum(axis=axis, keepdims=keepdims)
 
 
 def prod(
@@ -215,7 +181,7 @@ def prod(
     [720]
 
     """
-    return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims)
+    return inp.prod(axis=axis, keepdims=keepdims)
 
 
 def mean(
@@ -248,7 +214,7 @@ def mean(
     [3.5]
 
     """
-    return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims)
+    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
 
 
 def median(
@@ -362,7 +328,7 @@ def min(
     [1]
 
     """
-    return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims)
+    return inp.min(axis=axis, keepdims=keepdims)
 
 
 def max(
@@ -394,7 +360,7 @@ def max(
     [6]
 
     """
-    return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims)
+    return inp.max(axis=axis, keepdims=keepdims)
 
 
 def norm(
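
With the module-level _reduce helper removed, the functional reducers are thin forwards to the tensor methods, so both spellings take keepdims (and mean casts to float32 before reducing, per the change above). A hedged usage sketch, assuming a MegEngine build containing this commit where megengine.tensor and megengine.functional are importable as in the 1.0 release line:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="float32").reshape(2, 3))
print(F.sum(x, axis=1, keepdims=True).shape)   # (2, 1)
print(x.sum(axis=1, keepdims=True).shape)      # same result via the tensor method
print(F.mean(x, axis=0).shape)                 # (3,), computed in float32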


imperative/python/megengine/functional/nn.py (+1, -1)

@@ -580,7 +580,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
""" """
if axis is None: if axis is None:
axis = _get_softmax_axis(len(inp.shape)) axis = _get_softmax_axis(len(inp.shape))
offset = inp.max(axis=axis).detach()
offset = inp.max(axis=axis, keepdims=True).detach()
cached = exp(inp - offset) cached = exp(inp - offset)
down = sum(cached, axis=axis, keepdims=True) down = sum(cached, axis=axis, keepdims=True)
return cached / down return cached / down


imperative/python/test/unit/test_tensor_wrapper.py (+1, -1)

@@ -38,7 +38,7 @@ def test_reduce():
     for m in ["sum", "prod", "min", "max", "mean"]:
         x_np = np.random.rand(10).astype("float32")
         x = TensorWrapper(x_np)
-        y = getattr(x, m)(-1)
+        y = getattr(x, m)(axis=-1, keepdims=True)
         np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
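
The NumPy reference on the right-hand side is still the plain reduction: for a 1-D input the keepdims result has shape (1,) and assert_almost_equal broadcasts it against NumPy's scalar, so the comparison still holds. A small NumPy-only check of that assumption:

import numpy as np

x_np = np.random.rand(10).astype("float32")
y = x_np.sum(-1, keepdims=True)                              # shape (1,), like the keepdims tensor result
np.testing.assert_almost_equal(y, x_np.sum(-1), decimal=6)   # scalar broadcasts against (1,)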





