GitOrigin-RevId: 8ed95e0fb8
@@ -134,15 +134,54 @@ def _logical_binary_elwise(mode, rev=False):
     return f


+def _remove_axis(inp: Tensor, axis) -> Tensor:
+    Param = builtin.AxisAddRemove.Param
+
+    def get_axes():
+        if axis is None:
+            return [i for i, s in enumerate(inp.shape) if s == 1]
+        try:
+            return [int(axis)]
+        except (TypeError, ValueError):
+            pass
+        return list(map(int, axis))
+
+    axis = get_axes()
+    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
+    axis = [a - i for i, a in enumerate(axis)]
+
+    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
+    op = builtin.AxisAddRemove(param=param)
+    (result,) = apply(op, inp)
+    return result
+
+
 def _reduce(mode):
-    def f(self, axis=None):
-        inp = self
+    def f(self, axis=None, keepdims: bool = False):
+        data = self
+        (data,) = utils.convert_inputs(data)
         if axis is None:
-            inp = self.flatten()
-            axis = 0
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = utils.convert_inputs(inp)
-        (result,) = apply(op, result)
+            data = data.reshape(-1)
+            assert not keepdims, "can not set axis=None and keepdims=True"
+            op = builtin.Reduce(mode=mode, axis=0)
+            (result,) = apply(op, data)
+        elif isinstance(axis, collections.Iterable):
+            axis = list(axis)
+            axis.sort(reverse=True)
+            for ai in axis:
+                op = builtin.Reduce(mode=mode, axis=ai)
+                (data,) = apply(op, data)
+                if not keepdims:
+                    data = _remove_axis(data, ai)
+            result = data
+        else:
+            op = builtin.Reduce(mode=mode, axis=axis)
+            (result,) = apply(op, data)
+            if not keepdims:
+                result = _remove_axis(result, axis)
         return result

     return f
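For reference, the keepdims handling added to the wrapper's _reduce mirrors ordinary NumPy reduce semantics: with an iterable of axes, the axes are reduced from highest to lowest so earlier reductions do not shift the remaining indices, and each reduced axis is squeezed out only when keepdims is False. A minimal NumPy-only sketch of the same control flow (the helper name reduce_like is hypothetical and not part of this change):

    import numpy as np

    def reduce_like(data, op, axis=None, keepdims=False):
        # Flatten for axis=None, matching the assert in the diff above.
        if axis is None:
            assert not keepdims, "can not set axis=None and keepdims=True"
            return op(data.reshape(-1), axis=0)
        # Reduce the highest axis first so lower axis indices stay valid.
        if isinstance(axis, (list, tuple)):
            for ai in sorted(axis, reverse=True):
                data = op(data, axis=ai, keepdims=True)
                if not keepdims:
                    data = np.squeeze(data, axis=ai)
            return data
        data = op(data, axis=axis, keepdims=True)
        return data if keepdims else np.squeeze(data, axis=axis)

    x = np.arange(24, dtype="float32").reshape(2, 3, 4)
    assert reduce_like(x, np.sum, axis=(0, 2)).shape == (3,)
    assert reduce_like(x, np.sum, axis=(0, 2), keepdims=True).shape == (1, 3, 1)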
@@ -176,15 +176,15 @@ def cross_entropy_with_softmax(
     num_classes = pred.shape[axis]

     # Denominator of the softmax
-    offset = pred.max(axis=axis).detach()
+    offset = pred.max(axis=axis, keepdims=True).detach()
     pred = pred - offset
-    down = exp(pred).sum(axis=axis)
+    down = exp(pred).sum(axis=axis, keepdims=True)

     up = indexing_one_hot(pred, label, axis)

     if label_smooth != 0:
         factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor
+        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor

     return (log(down) - up).mean()
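The keepdims=True arguments above keep the reduced class axis in place so that offset and down still broadcast against pred for inputs of any rank; with the tensor methods now squeezing reduced axes by default, dropping the flag would misalign the subtraction and the log term. A small NumPy shape check illustrating the intent (shapes are illustrative only):

    import numpy as np

    pred = np.random.rand(8, 10).astype("float32")            # (batch, num_classes), axis=1
    offset = pred.max(axis=1, keepdims=True)                   # (8, 1), broadcasts over classes
    down = np.exp(pred - offset).sum(axis=1, keepdims=True)    # (8, 1)
    assert offset.shape == down.shape == (8, 1)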
@@ -117,40 +117,6 @@ def sign(inp: Tensor):
     raise NotImplementedError


-def _reduce(
-    data,
-    *,
-    mode,
-    axis: Optional[Union[int, Sequence[int]]] = None,
-    keepdims: bool = False
-):
-    (data,) = utils.convert_inputs(data)
-    if axis is None:
-        data = data.reshape(-1)
-        assert not keepdims, "can not set axis=None and keepdims=True"
-        op = builtin.Reduce(mode=mode, axis=0)
-        (result,) = apply(op, data)
-    elif isinstance(axis, collections.Iterable):
-        axis = list(axis)
-        axis.sort(reverse=True)
-        for ai in axis:
-            op = builtin.Reduce(mode=mode, axis=ai)
-            (data,) = apply(op, data)
-            if not keepdims:
-                data = remove_axis(data, ai)
-        result = data
-    else:
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = apply(op, data)
-        if not keepdims:
-            result = remove_axis(result, axis)
-    return result
-
-
 def sum(
     inp: Tensor,
     axis: Optional[Union[int, Sequence[int]]] = None,
@@ -182,7 +148,7 @@ def sum(
         [21]

     """
-    return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims)
+    return inp.sum(axis=axis, keepdims=keepdims)


 def prod(
@@ -215,7 +181,7 @@ def prod(
         [720]

     """
-    return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims)
+    return inp.prod(axis=axis, keepdims=keepdims)


 def mean(
@@ -248,7 +214,7 @@ def mean(
         [3.5]

     """
-    return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims)
+    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)


 def median(
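The functional reducers above now simply delegate to the corresponding tensor methods; mean additionally casts to float32 first, presumably so integer inputs still yield a fractional mean rather than an integer-typed reduce. A NumPy analogue of that cast (illustrative only):

    import numpy as np

    x = np.arange(1, 7, dtype="int32")           # 1..6
    assert x.astype("float32").mean() == 3.5     # cast first, then reduce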
@@ -362,7 +328,7 @@ def min(
         [1]

     """
-    return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims)
+    return inp.min(axis=axis, keepdims=keepdims)


 def max(
@@ -394,7 +360,7 @@ def max(
         [6]

     """
-    return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims)
+    return inp.max(axis=axis, keepdims=keepdims)


 def norm(
@@ -580,7 +580,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     """
     if axis is None:
         axis = _get_softmax_axis(len(inp.shape))
-    offset = inp.max(axis=axis).detach()
+    offset = inp.max(axis=axis, keepdims=True).detach()
     cached = exp(inp - offset)
     down = sum(cached, axis=axis, keepdims=True)
     return cached / down
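softmax follows the same pattern: subtract a detached per-axis max for numerical stability and keep the reduced axis so both the offset and the denominator broadcast against the input. A NumPy reference of the same computation (the name softmax_ref is hypothetical):

    import numpy as np

    def softmax_ref(x, axis=-1):
        off = x.max(axis=axis, keepdims=True)      # stays broadcastable against x
        e = np.exp(x - off)
        return e / e.sum(axis=axis, keepdims=True)

    x = np.random.rand(2, 3, 4).astype("float32")
    assert np.allclose(softmax_ref(x, axis=1).sum(axis=1), 1.0, atol=1e-6)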
@@ -38,7 +38,7 @@ def test_reduce():
     for m in ["sum", "prod", "min", "max", "mean"]:
         x_np = np.random.rand(10).astype("float32")
         x = TensorWrapper(x_np)
-        y = getattr(x, m)(-1)
+        y = getattr(x, m)(axis=-1, keepdims=True)
         np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
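Note that with keepdims=True the wrapper result for a 1-D input has shape (1,) while the NumPy reference getattr(x_np, m)(-1) is 0-d; assert_almost_equal still passes because the two broadcast. A quick NumPy-only check of that expectation:

    import numpy as np

    x_np = np.random.rand(10).astype("float32")
    assert x_np.sum(-1, keepdims=True).shape == (1,)   # what keepdims=True produces
    np.testing.assert_almost_equal(x_np.sum(-1, keepdims=True), x_np.sum(-1), decimal=6)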