@@ -50,7 +50,19 @@ from .loss import ( | |||
square_loss, | |||
triplet_margin_loss, | |||
) | |||
from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum | |||
from .math import ( | |||
argmax, | |||
argmin, | |||
logsumexp, | |||
max, | |||
mean, | |||
min, | |||
norm, | |||
normalize, | |||
prod, | |||
sqrt, | |||
sum, | |||
) | |||
from .nn import ( | |||
assert_equal, | |||
avg_pool2d, | |||
@@ -6,12 +6,15 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional | |||
import math | |||
import numbers | |||
from typing import Optional, Sequence, Union | |||
import megengine._internal as mgb | |||
from ..core import Tensor, wrap_io_tensor | |||
from .elemwise import clamp | |||
from .elemwise import clamp, exp, isinf, log | |||
from .tensor import remove_axis, where, zeros_like | |||
@wrap_io_tensor | |||
@@ -296,3 +299,35 @@ def normalize( | |||
return inp / clamp(norm(inp, p), lower=eps) | |||
else: | |||
return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) | |||
def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False): | |||
r""" | |||
Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. The computation is numerically stabilized. | |||
.. math:: | |||
\mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n)) | |||
:param inp: The input tensor. | |||
:param axis: Axis over which the sum is taken. It can be a single axis or a list of axes. | |||
:param keepdims: whether to retain :attr:`axis` or not for the output tensor. | |||
""" | |||
if isinstance(axis, numbers.Integral): | |||
axis = (axis,) | |||
max_value = inp | |||
for dim in axis: | |||
max_value = max_value.max(axis=dim, keepdims=True) | |||
max_value = where( | |||
isinf(max_value).astype("int32"), zeros_like(max_value), max_value | |||
) | |||
x = exp(inp - max_value) | |||
for dim in axis: | |||
x = x.sum(axis=dim, keepdims=True) | |||
x = max_value + log(x) | |||
if not keepdims: | |||
axis = sorted(axis, reverse=True) | |||
for i in axis: | |||
x = remove_axis(x, axis=i) | |||
return x |
@@ -9,9 +9,12 @@ | |||
import numpy as np | |||
def assertTensorClose(v0, v1, *, max_err=1e-6, name=None): | |||
def assertTensorClose( | |||
v0, v1, *, max_err: float = 1e-6, allow_special_values: bool = False, name=None | |||
): | |||
""" | |||
max_err: relative error | |||
:param allow_special_values: whether to allow :attr:`v0` and :attr:`v1` to contain inf and nan values. | |||
:param max_err: relative error | |||
""" | |||
__tracebackhide__ = True # pylint: disable=unused-variable | |||
@@ -20,9 +23,30 @@ def assertTensorClose(v0, v1, *, max_err=1e-6, name=None): | |||
), "Two Tensor must have same dtype, but the inputs are {} and {}".format( | |||
v0.dtype, v1.dtype | |||
) | |||
v0 = np.ascontiguousarray(v0, dtype=np.float32) | |||
v1 = np.ascontiguousarray(v1, dtype=np.float32) | |||
assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1) | |||
v0 = np.ascontiguousarray(v0, dtype=np.float32).copy() | |||
v1 = np.ascontiguousarray(v1, dtype=np.float32).copy() | |||
if allow_special_values: | |||
# check nan and rm it | |||
v0_nan_mask = np.isnan(v0) | |||
if np.any(v0_nan_mask): | |||
assert np.array_equiv(v0_nan_mask, np.isnan(v1)), (v0, v1) | |||
v0[v0_nan_mask] = 0 | |||
v1[v0_nan_mask] = 0 | |||
# check inf and rm it | |||
v0_inf_mask = v0 == float("inf") | |||
if np.any(v0_inf_mask): | |||
assert np.array_equiv(v0_inf_mask, v1 == float("inf")), (v0, v1) | |||
v0[v0_inf_mask] = 0 | |||
v1[v0_inf_mask] = 0 | |||
# check -inf and rm it | |||
v0_inf_mask = v0 == float("-inf") | |||
if np.any(v0_inf_mask): | |||
assert np.array_equiv(v0_inf_mask, v1 == float("-inf")), (v0, v1) | |||
v0[v0_inf_mask] = 0 | |||
v1[v0_inf_mask] = 0 | |||
else: | |||
assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1) | |||
assert v0.shape == v1.shape, "Two tensor must have same shape({} v.s. {})".format( | |||
v0.shape, v1.shape | |||
) | |||
@@ -6,10 +6,14 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from functools import partial | |||
import numpy as np | |||
from helpers import opr_test | |||
import megengine.functional as F | |||
from megengine.test import assertTensorClose | |||
def common_test_reduce(opr, ref_opr): | |||
@@ -86,7 +90,6 @@ def test_sqrt(): | |||
def test_normalize(): | |||
from functools import partial | |||
cases = [ | |||
{"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2) | |||
@@ -112,3 +115,54 @@ def test_normalize(): | |||
cases[0]["input"][0, 0, 0, :] = 0 | |||
cases[1]["input"][0, 0, 0, :] = 0 | |||
opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) | |||
def test_logsumexp(): | |||
x = np.arange(10).astype(np.float32) | |||
expected = np.log(np.sum(np.exp(x))) | |||
cases = [{"input": x, "output": expected}] | |||
compare_fn = partial(assertTensorClose, allow_special_values=True) | |||
# large value check | |||
n = 100 | |||
x = np.full(n, 10000, dtype=np.float32) | |||
expected = 10000 + np.log(n) | |||
cases.append({"input": x, "output": expected.astype(np.float32)}) | |||
opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) | |||
# special value check | |||
x = np.array([np.inf], dtype=np.float32) | |||
expected = x | |||
cases = [{"input": x, "output": expected}] | |||
x = np.array([-np.inf, 0.0], dtype=np.float32) | |||
expected = np.zeros(1).astype(np.float32) | |||
cases.append({"input": x, "output": expected}) | |||
opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) | |||
x = np.array([np.nan], dtype=np.float32) | |||
expected = x | |||
cases = [{"input": x, "output": expected}] | |||
x = np.array([-np.inf, 1], dtype=np.float32) | |||
expected = np.array([1.0], dtype=np.float32) | |||
cases.append({"input": x, "output": expected}) | |||
opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) | |||
# keepdims check | |||
x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32) | |||
expected = np.array([[1e10], [-1e10]], dtype=np.float32) | |||
cases = [{"input": x, "output": expected}] | |||
x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32) | |||
expected = np.array([[1e10], [np.inf]], dtype=np.float32) | |||
cases.append({"input": x, "output": expected}) | |||
opr_test(cases, F.logsumexp, axis=1, keepdims=True, compare_fn=compare_fn) | |||
# multiple axes check | |||
x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32) | |||
expected = np.array([1e10], dtype=np.float32) | |||
cases = [{"input": x, "output": expected}] | |||
x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32) | |||
expected = np.array([np.inf], dtype=np.float32) | |||
cases.append({"input": x, "output": expected}) | |||
opr_test(cases, F.logsumexp, axis=(0, 1), keepdims=False, compare_fn=compare_fn) |