From 5507a29bd339ccf156a6fe114967aeb9073c0def Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 22 Apr 2020 10:14:49 +0800 Subject: [PATCH] feat(mge/functional): add normalize opr GitOrigin-RevId: 572a32f5633bfd775e5cc2209e4c9005f49e3761 --- python_module/megengine/functional/__init__.py | 2 +- python_module/megengine/functional/math.py | 29 +++++++++++++++++++++++-- python_module/test/unit/functional/test_math.py | 29 +++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py index 6b262bfd..f84181e6 100644 --- a/python_module/megengine/functional/__init__.py +++ b/python_module/megengine/functional/__init__.py @@ -48,7 +48,7 @@ from .loss import ( square_loss, triplet_margin_loss, ) -from .math import argmax, argmin, max, mean, min, norm, prod, sqrt, sum +from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum from .nn import ( assert_equal, avg_pool2d, diff --git a/python_module/megengine/functional/math.py b/python_module/megengine/functional/math.py index 2ebac2f6..8bcc2cc1 100644 --- a/python_module/megengine/functional/math.py +++ b/python_module/megengine/functional/math.py @@ -11,6 +11,7 @@ from typing import Optional import megengine._internal as mgb from ..core import Tensor, wrap_io_tensor +from .elemwise import clamp @wrap_io_tensor @@ -199,8 +200,7 @@ def sqrt(inp: Tensor) -> Tensor: return mgb.opr.sqrt(inp) -@wrap_io_tensor -def norm(inp: Tensor, p=2, axis: Optional[int] = None, keepdims=False): +def norm(inp: Tensor, p: int = 2, axis: Optional[int] = None, keepdims=False): """Calculate ``p``-norm of input tensor along certain axis. :param inp: The input tensor @@ -271,3 +271,28 @@ def argmax(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> T """ return mgb.opr.argmax(inp, axis, keepdims) + + +def normalize( + inp: Tensor, p: int = 2, axis: Optional[int] = None, eps: float = 1e-12 +) -> Tensor: + r"""Perform :math:`L_p` normalization of input tensor along certain axis. + + For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each + :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: + + .. math:: + v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. + + :param inp: the input tensor + :param p: power of value ``p`` applied to ``inp``. Default: 2 + :param axis: The dimension to reduce. If None, all the dimensions will be reduced + to calculate the norm. Default: None + :param eps: a small value to avoid division by zero. Default: 1e-12 + :return: the normalized output tensor + + """ + if axis is None: + return inp / clamp(norm(inp, p), lower=eps) + else: + return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) diff --git a/python_module/test/unit/functional/test_math.py b/python_module/test/unit/functional/test_math.py index b5cb4a49..9354fee8 100644 --- a/python_module/test/unit/functional/test_math.py +++ b/python_module/test/unit/functional/test_math.py @@ -83,3 +83,32 @@ def test_sqrt(): cases = [{"input": d1}, {"input": d2}] opr_test(cases, F.sqrt, ref_fn=np.sqrt) + + +def test_normalize(): + from functools import partial + + cases = [ + {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2) + ] + + def np_normalize(x, p=2, axis=None, eps=1e-12): + if axis is None: + norm = np.sum(x ** p) ** (1.0 / p) + else: + norm = np.sum(x ** p, axis=axis, keepdims=True) ** (1.0 / p) + return x / np.clip(norm, a_min=eps, a_max=np.inf) + + # Test L-2 norm along all dimensions + opr_test(cases, F.normalize, ref_fn=np_normalize) + + # Test L-1 norm along all dimensions + opr_test(cases, partial(F.normalize, p=1), ref_fn=partial(np_normalize, p=1)) + + # Test L-2 norm along the second dimension + opr_test(cases, partial(F.normalize, axis=1), ref_fn=partial(np_normalize, axis=1)) + + # Test some norm == 0 + cases[0]["input"][0, 0, 0, :] = 0 + cases[1]["input"][0, 0, 0, :] = 0 + opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))