@@ -50,7 +50,19 @@ from .loss import (
     square_loss,
     triplet_margin_loss,
 )
-from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum
+from .math import (
+    argmax,
+    argmin,
+    logsumexp,
+    max,
+    mean,
+    min,
+    norm,
+    normalize,
+    prod,
+    sqrt,
+    sum,
+)
 from .nn import (
     assert_equal,
     avg_pool2d,
@@ -6,12 +6,15 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Optional
+import math
+import numbers
+from typing import Optional, Sequence, Union

 import megengine._internal as mgb

 from ..core import Tensor, wrap_io_tensor
-from .elemwise import clamp
+from .elemwise import clamp, exp, isinf, log
+from .tensor import remove_axis, where, zeros_like


 @wrap_io_tensor
@@ -296,3 +299,35 @@ def normalize(
         return inp / clamp(norm(inp, p), lower=eps)
     else:
         return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps)
+
+
+def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False):
+    r"""
+    Computes the log of the sum of exponentials of the input along the given
+    :attr:`axis`, using a numerically stabilized algorithm.
+
+    .. math::
+
+        \mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n))
+
+    :param inp: the input tensor.
+    :param axis: the axis over which the sum is taken; either a single axis or a list of axes.
+    :param keepdims: whether to retain :attr:`axis` in the output tensor.
+    """
+    if isinstance(axis, numbers.Integral):
+        axis = (axis,)
+    max_value = inp
+    for dim in axis:
+        max_value = max_value.max(axis=dim, keepdims=True)
+    max_value = where(
+        isinf(max_value).astype("int32"), zeros_like(max_value), max_value
+    )
+    x = exp(inp - max_value)
+    for dim in axis:
+        x = x.sum(axis=dim, keepdims=True)
+    x = max_value + log(x)
+    if not keepdims:
+        axis = sorted(axis, reverse=True)
+        for i in axis:
+            x = remove_axis(x, axis=i)
+    return x
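Note: the implementation above uses the standard max-subtraction trick for numerical stability: the per-axis maximum is subtracted before exponentiating and added back after the log, and an infinite maximum is replaced by zero so that inf - inf never produces nan. A minimal NumPy sketch of the same idea, for reference only (the helper name np_logsumexp is made up for illustration and is not part of this patch):

import numpy as np

def np_logsumexp(x, axis, keepdims=False):
    # subtract the max so exp() cannot overflow
    m = np.max(x, axis=axis, keepdims=True)
    # an infinite max would give inf - inf = nan below; use 0 instead
    m = np.where(np.isinf(m), 0.0, m)
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)

This mirrors the MegEngine implementation step for step and reproduces its special-value behavior, e.g. np_logsumexp(np.array([-np.inf, 0.0]), axis=0) == 0.0.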
@@ -9,9 +9,12 @@
 import numpy as np


-def assertTensorClose(v0, v1, *, max_err=1e-6, name=None):
+def assertTensorClose(
+    v0, v1, *, max_err: float = 1e-6, allow_special_values: bool = False, name=None
+):
     """
-    max_err: relative error
+    :param allow_special_values: whether to allow :attr:`v0` and :attr:`v1` to contain inf and nan values.
+    :param max_err: maximum allowed relative error.
     """

     __tracebackhide__ = True  # pylint: disable=unused-variable
@@ -20,9 +23,30 @@ def assertTensorClose(v0, v1, *, max_err=1e-6, name=None):
     ), "Two Tensor must have same dtype, but the inputs are {} and {}".format(
         v0.dtype, v1.dtype
     )
-    v0 = np.ascontiguousarray(v0, dtype=np.float32)
-    v1 = np.ascontiguousarray(v1, dtype=np.float32)
-    assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1)
+    v0 = np.ascontiguousarray(v0, dtype=np.float32).copy()
+    v1 = np.ascontiguousarray(v1, dtype=np.float32).copy()
+
+    if allow_special_values:
+        # check that nan positions match, then zero them out
+        v0_nan_mask = np.isnan(v0)
+        if np.any(v0_nan_mask):
+            assert np.array_equiv(v0_nan_mask, np.isnan(v1)), (v0, v1)
+            v0[v0_nan_mask] = 0
+            v1[v0_nan_mask] = 0
+        # check that inf positions match, then zero them out
+        v0_inf_mask = v0 == float("inf")
+        if np.any(v0_inf_mask):
+            assert np.array_equiv(v0_inf_mask, v1 == float("inf")), (v0, v1)
+            v0[v0_inf_mask] = 0
+            v1[v0_inf_mask] = 0
+        # check that -inf positions match, then zero them out
+        v0_inf_mask = v0 == float("-inf")
+        if np.any(v0_inf_mask):
+            assert np.array_equiv(v0_inf_mask, v1 == float("-inf")), (v0, v1)
+            v0[v0_inf_mask] = 0
+            v1[v0_inf_mask] = 0
+    else:
+        assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1)
     assert v0.shape == v1.shape, "Two tensor must have same shape({} v.s. {})".format(
         v0.shape, v1.shape
     )
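The special-value path above first checks that nan, inf, and -inf occur at exactly the same positions in both tensors, then zeroes those entries so the remaining finite values can go through the usual relative-error comparison. A small standalone demonstration of the masking idea (plain NumPy, not part of the patch):

import numpy as np

a = np.array([np.nan, np.inf, 1.0], dtype=np.float32)
b = np.array([np.nan, np.inf, 1.0], dtype=np.float32)

# positions of special values must match exactly ...
assert np.array_equal(np.isnan(a), np.isnan(b))
assert np.array_equal(a == np.inf, b == np.inf)

# ... then they are zeroed out so the finite part can be compared numerically
mask = ~np.isfinite(a)
a[mask] = 0
b[mask] = 0
assert np.allclose(a, b)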
@@ -6,10 +6,14 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from functools import partial
+
 import numpy as np
 from helpers import opr_test

 import megengine.functional as F
+from megengine.test import assertTensorClose


 def common_test_reduce(opr, ref_opr):
@@ -86,7 +90,6 @@ def test_sqrt():


 def test_normalize():
-    from functools import partial

     cases = [
         {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
@@ -112,3 +115,54 @@ def test_normalize():
     cases[0]["input"][0, 0, 0, :] = 0
     cases[1]["input"][0, 0, 0, :] = 0
     opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))
+
+
+def test_logsumexp():
+    x = np.arange(10).astype(np.float32)
+    expected = np.log(np.sum(np.exp(x)))
+    cases = [{"input": x, "output": expected}]
+    compare_fn = partial(assertTensorClose, allow_special_values=True)
+    # large value check
+    n = 100
+    x = np.full(n, 10000, dtype=np.float32)
+    expected = 10000 + np.log(n)
+    cases.append({"input": x, "output": expected.astype(np.float32)})
+    opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn)
+
+    # special value check
+    x = np.array([np.inf], dtype=np.float32)
+    expected = x
+    cases = [{"input": x, "output": expected}]
+
+    x = np.array([-np.inf, 0.0], dtype=np.float32)
+    expected = np.zeros(1).astype(np.float32)
+    cases.append({"input": x, "output": expected})
+    opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn)
+
+    x = np.array([np.nan], dtype=np.float32)
+    expected = x
+    cases = [{"input": x, "output": expected}]
+
+    x = np.array([-np.inf, 1], dtype=np.float32)
+    expected = np.array([1.0], dtype=np.float32)
+    cases.append({"input": x, "output": expected})
+    opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn)
+
+    # keepdims check
+    x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32)
+    expected = np.array([[1e10], [-1e10]], dtype=np.float32)
+    cases = [{"input": x, "output": expected}]
+
+    x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32)
+    expected = np.array([[1e10], [np.inf]], dtype=np.float32)
+    cases.append({"input": x, "output": expected})
+    opr_test(cases, F.logsumexp, axis=1, keepdims=True, compare_fn=compare_fn)
+
+    # multiple axes check
+    x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32)
+    expected = np.array([1e10], dtype=np.float32)
+    cases = [{"input": x, "output": expected}]
+
+    x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32)
+    expected = np.array([np.inf], dtype=np.float32)
+    cases.append({"input": x, "output": expected})
+    opr_test(cases, F.logsumexp, axis=(0, 1), keepdims=False, compare_fn=compare_fn)
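The large-value case in the test is easy to verify by hand: all 100 entries equal 10000, so logsumexp(x) = log(100 * e^10000) = 10000 + log(100) ≈ 10004.605, whereas a naive log(sum(exp(x))) would overflow, since exp(10000) is not representable in float32. Assuming the usual tensor constructor (megengine.tensor here is an assumption, not shown in this patch), usage would look roughly like:

import numpy as np
import megengine
import megengine.functional as F

# megengine.tensor is the assumed entry point for creating tensors
x = megengine.tensor(np.full(100, 10000, dtype=np.float32))
y = F.logsumexp(x, axis=0)  # == 10000 + log(100), ~10004.605
print(y.numpy())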