GitOrigin-RevId: 8ed95e0fb8
tags/v1.0.0-rc1
@@ -134,15 +134,54 @@ def _logical_binary_elwise(mode, rev=False):
     return f
 
 
+def _remove_axis(inp: Tensor, axis) -> Tensor:
+    Param = builtin.AxisAddRemove.Param
+
+    def get_axes():
+        if axis is None:
+            return [i for i, s in enumerate(inp.shape) if s == 1]
+        try:
+            return [int(axis)]
+        except (TypeError, ValueError):
+            pass
+        return list(map(int, axis))
+
+    axis = get_axes()
+    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
+    axis = [a - i for i, a in enumerate(axis)]
+
+    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
+    op = builtin.AxisAddRemove(param=param)
+    (result,) = apply(op, inp)
+    return result
+
+
 def _reduce(mode):
-    def f(self, axis=None):
-        inp = self
+    def f(self, axis=None, keepdims: bool = False):
+        data = self
+        (data,) = utils.convert_inputs(data)
         if axis is None:
-            inp = self.flatten()
-            axis = 0
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = utils.convert_inputs(inp)
-        (result,) = apply(op, result)
+            data = data.reshape(-1)
+            assert not keepdims, "can not set axis=None and keepdims=True"
+            op = builtin.Reduce(mode=mode, axis=0)
+            (result,) = apply(op, data)
+        elif isinstance(axis, collections.Iterable):
+            axis = list(axis)
+            axis.sort(reverse=True)
+            for ai in axis:
+                op = builtin.Reduce(mode=mode, axis=ai)
+                (data,) = apply(op, data)
+                if not keepdims:
+                    data = _remove_axis(data, ai)
+            result = data
+        else:
+            op = builtin.Reduce(mode=mode, axis=axis)
+            (result,) = apply(op, data)
+            if not keepdims:
+                result = _remove_axis(result, axis)
         return result
 
     return f
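Two pieces of axis bookkeeping in this hunk deserve a note. `_remove_axis` sorts the axes ascending and then rewrites each index as `a - i`: by the time the i-th axis is removed, the `i` smaller axes are already gone, so every later index has shifted down by `i`. The multi-axis branch of `_reduce` sidesteps that correction by walking a reverse-sorted list, since removing the largest axis first never disturbs the smaller indices. A minimal NumPy sketch of both conventions, with `np.squeeze` standing in for the AxisAddRemove op:

    import numpy as np

    x = np.zeros((2, 1, 3, 1))

    # _remove_axis style: sort ascending, shift each index down by the
    # number of axes already removed before it.
    axes = sorted([1, 3])
    shifted = [a - i for i, a in enumerate(axes)]   # [1, 2]
    y = x
    for a in shifted:
        y = np.squeeze(y, axis=a)
    assert y.shape == (2, 3)

    # _reduce style: walk the axes largest-first so earlier removals
    # never shift the axes still to be processed.
    y = x
    for a in sorted([1, 3], reverse=True):
        y = np.squeeze(y, axis=a)
    assert y.shape == (2, 3)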
@@ -176,15 +176,15 @@ def cross_entropy_with_softmax(
     num_classes = pred.shape[axis]
 
     # Denominator of the softmax
-    offset = pred.max(axis=axis).detach()
+    offset = pred.max(axis=axis, keepdims=True).detach()
     pred = pred - offset
-    down = exp(pred).sum(axis=axis)
+    down = exp(pred).sum(axis=axis, keepdims=True)
 
     up = indexing_one_hot(pred, label, axis)
 
     if label_smooth != 0:
         factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor
+        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
 
     return (log(down) - up).mean()
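The `keepdims=True` additions here are what make the subsequent broadcasts line up. With `pred` of shape `(N, C)` and `axis=1`, a plain `max` returns shape `(N,)`, which cannot broadcast against `(N, C)` as a per-row offset; keeping the reduced axis yields `(N, 1)`, which can. A NumPy illustration of the shape difference:

    import numpy as np

    pred = np.random.rand(4, 10).astype("float32")   # (N, C) logits

    dropped = pred.max(axis=1)                # shape (4,); pred - dropped
                                              # raises a broadcast error
    offset = pred.max(axis=1, keepdims=True)  # shape (4, 1)

    stable = pred - offset                    # per-row shift, shape (4, 10)
    assert stable.shape == (4, 10)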
@@ -117,40 +117,6 @@ def sign(inp: Tensor):
     raise NotImplementedError
 
 
-def _reduce(
-    data,
-    *,
-    mode,
-    axis: Optional[Union[int, Sequence[int]]] = None,
-    keepdims: bool = False
-):
-    (data,) = utils.convert_inputs(data)
-    if axis is None:
-        data = data.reshape(-1)
-        assert not keepdims, "can not set axis=None and keepdims=True"
-        op = builtin.Reduce(mode=mode, axis=0)
-        (result,) = apply(op, data)
-    elif isinstance(axis, collections.Iterable):
-        axis = list(axis)
-        axis.sort(reverse=True)
-        for ai in axis:
-            op = builtin.Reduce(mode=mode, axis=ai)
-            (data,) = apply(op, data)
-            if not keepdims:
-                data = remove_axis(data, ai)
-        result = data
-    else:
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = apply(op, data)
-        if not keepdims:
-            result = remove_axis(result, axis)
-    return result
-
-
 def sum(
     inp: Tensor,
     axis: Optional[Union[int, Sequence[int]]] = None,
@@ -182,7 +148,7 @@ def sum(
         [21]
 
     """
-    return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims)
+    return inp.sum(axis=axis, keepdims=keepdims)
 
 
 def prod(
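With the shared functional `_reduce` helper deleted above, `sum` and the wrappers that follow become one-line delegates to the tensor methods patched in by the first hunk, so multi-axis handling and `keepdims` live in a single code path. A NumPy sketch of the behavior the docstring promises, with NumPy used purely as a stand-in for the tensor type:

    import numpy as np

    x = np.arange(1, 7, dtype="float32").reshape(2, 3)

    print(x.sum())                       # 21.0, matching the docstring's [21]
    print(x.sum(axis=1))                 # [ 6. 15.],      shape (2,)
    print(x.sum(axis=1, keepdims=True))  # [[ 6.] [15.]],  shape (2, 1)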
@@ -215,7 +181,7 @@ def prod(
         [720]
 
     """
-    return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims)
+    return inp.prod(axis=axis, keepdims=keepdims)
 
 
 def mean(
@@ -248,7 +214,7 @@ def mean(
         [3.5]
 
     """
-    return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims)
+    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
 
 
 def median(
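The mean wrapper differs from the others: it casts to float32 before reducing, presumably to guarantee a floating-point result for integer inputs rather than relying on the backend's MEAN mode to promote. The pitfall the cast avoids, in NumPy terms (NumPy's own `mean` already promotes, so integer division is shown explicitly for contrast):

    import numpy as np

    x = np.array([1, 2], dtype=np.int32)

    print(x.sum() // x.size)            # 1   -- a mean kept in int32 truncates
    print(x.astype(np.float32).mean())  # 1.5 -- cast first, as mean() above does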
@@ -362,7 +328,7 @@ def min(
         [1]
 
     """
-    return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims)
+    return inp.min(axis=axis, keepdims=keepdims)
 
 
 def max(
@@ -394,7 +360,7 @@ def max(
         [6]
 
     """
-    return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims)
+    return inp.max(axis=axis, keepdims=keepdims)
 
 
 def norm(
@@ -580,7 +580,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     """
     if axis is None:
         axis = _get_softmax_axis(len(inp.shape))
-    offset = inp.max(axis=axis).detach()
+    offset = inp.max(axis=axis, keepdims=True).detach()
     cached = exp(inp - offset)
     down = sum(cached, axis=axis, keepdims=True)
     return cached / down
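Subtracting the detached per-axis max before exponentiating is the standard overflow guard: the largest argument to `exp` becomes 0, so nothing overflows, and the shift cancels in the final ratio. A self-contained NumPy version of the same computation:

    import numpy as np

    def softmax(x, axis=-1):
        # keepdims=True keeps the reduced axis so the subtraction broadcasts
        offset = x.max(axis=axis, keepdims=True)
        e = np.exp(x - offset)
        return e / e.sum(axis=axis, keepdims=True)

    x = np.array([[1000.0, 1001.0, 1002.0]])  # naive exp(x) would overflow
    print(softmax(x))        # [[0.09003057 0.24472847 0.66524096]]
    print(softmax(x).sum())  # 1.0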
@@ -38,7 +38,7 @@ def test_reduce():
     for m in ["sum", "prod", "min", "max", "mean"]:
         x_np = np.random.rand(10).astype("float32")
         x = TensorWrapper(x_np)
-        y = getattr(x, m)(-1)
+        y = getattr(x, m)(axis=-1, keepdims=True)
         np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
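One subtlety in the updated test: with `keepdims=True` the wrapped result has shape `(1,)`, while NumPy's own reduction of a 1-D array returns a scalar. `np.testing.assert_almost_equal` tolerates this pairing because the two broadcast, so the assertion pins down the values but not the shape:

    import numpy as np

    x_np = np.random.rand(10).astype("float32")

    kept = x_np.sum(axis=-1, keepdims=True)  # shape (1,)
    flat = x_np.sum(-1)                      # 0-d scalar

    # Passes: (1,) broadcasts against a scalar, so only values are compared.
    np.testing.assert_almost_equal(kept, flat, decimal=6)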