|
@@ -11,6 +11,7 @@ from typing import Optional |
|
|
import megengine._internal as mgb |
|
|
import megengine._internal as mgb |
|
|
|
|
|
|
|
|
from ..core import Tensor, wrap_io_tensor |
|
|
from ..core import Tensor, wrap_io_tensor |
|
|
|
|
|
from .elemwise import clamp |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@wrap_io_tensor |
|
|
@wrap_io_tensor |
|
@@ -199,8 +200,7 @@ def sqrt(inp: Tensor) -> Tensor: |
|
|
return mgb.opr.sqrt(inp) |
|
|
return mgb.opr.sqrt(inp) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@wrap_io_tensor |
|
|
|
|
|
def norm(inp: Tensor, p=2, axis: Optional[int] = None, keepdims=False): |
|
|
|
|
|
|
|
|
def norm(inp: Tensor, p: int = 2, axis: Optional[int] = None, keepdims=False): |
|
|
"""Calculate ``p``-norm of input tensor along certain axis. |
|
|
"""Calculate ``p``-norm of input tensor along certain axis. |
|
|
|
|
|
|
|
|
:param inp: The input tensor |
|
|
:param inp: The input tensor |
|
@@ -271,3 +271,28 @@ def argmax(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> T |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
return mgb.opr.argmax(inp, axis, keepdims) |
|
|
return mgb.opr.argmax(inp, axis, keepdims) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize( |
|
|
|
|
|
inp: Tensor, p: int = 2, axis: Optional[int] = None, eps: float = 1e-12 |
|
|
|
|
|
) -> Tensor: |
|
|
|
|
|
r"""Perform :math:`L_p` normalization of input tensor along certain axis. |
|
|
|
|
|
|
|
|
|
|
|
For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each |
|
|
|
|
|
:math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: |
|
|
|
|
|
|
|
|
|
|
|
.. math:: |
|
|
|
|
|
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. |
|
|
|
|
|
|
|
|
|
|
|
:param inp: the input tensor |
|
|
|
|
|
:param p: power of value ``p`` applied to ``inp``. Default: 2 |
|
|
|
|
|
:param axis: The dimension to reduce. If None, all the dimensions will be reduced |
|
|
|
|
|
to calculate the norm. Default: None |
|
|
|
|
|
:param eps: a small value to avoid division by zero. Default: 1e-12 |
|
|
|
|
|
:return: the normalized output tensor |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
if axis is None: |
|
|
|
|
|
return inp / clamp(norm(inp, p), lower=eps) |
|
|
|
|
|
else: |
|
|
|
|
|
return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) |