
feat(mge/functional): add hinge loss

GitOrigin-RevId: 64c89c1f8c
tags/v0.5.0
Megvii Engine Team 5 years ago
parent commit 3c49d1d324
3 changed files with 71 additions and 1 deletion
1. python_module/megengine/functional/__init__.py (+1, -0)
2. python_module/megengine/functional/loss.py (+44, -1)
3. python_module/test/unit/functional/test_functional.py (+26, -0)

python_module/megengine/functional/__init__.py (+1, -0)

@@ -43,6 +43,7 @@ from .loss import (
     binary_cross_entropy,
     cross_entropy,
     cross_entropy_with_softmax,
+    hinge_loss,
     l1_loss,
     nll_loss,
     square_loss,
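
With this export in place, hinge_loss becomes reachable from the functional namespace. A minimal usage sketch (mirroring the docstring example further below; assumes a MegEngine build that includes this commit):

    import megengine.functional as F
    from megengine import tensor

    pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
    label = tensor([[1, -1, -1], [-1, 1, 1]])  # -1/1 binary labels
    print(F.hinge_loss(pred, label).numpy())  # [1.5], per the docstring example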


python_module/megengine/functional/loss.py (+44, -1)

@@ -9,8 +9,9 @@
 import megengine._internal as mgb
 
 from ..core.tensor import Tensor
-from .elemwise import abs, equal, log, maximum, power
+from .elemwise import abs, equal, log, maximum, power, relu
 from .nn import assert_equal, indexing_one_hot
+from .tensor import where
 from .utils import zero_grad

@@ -297,3 +298,45 @@ def nll_loss(
     loss = indexing_one_hot(pred, label, axis) * mask
 
     return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)
+
+
+def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
+    r"""
+    Calculate the hinge loss, which is often used in SVMs.
+
+    The hinge loss can be described as:
+
+    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j \max(0, 1 - x_{ij} y_{ij})
+
+    :param pred: The input tensor representing the predicted score, shape is (N, C).
+    :param label: The input tensor representing the binary classification label, whose values should be -1 or 1, shape is (N, C).
+    :param norm: Specify the norm used to calculate the loss; should be "L1" or "L2".
+
+    Examples:
+
+    .. testcode::
+
+        from megengine import tensor
+        import megengine.functional as F
+
+        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
+        label = tensor([[1, -1, -1], [-1, 1, 1]])
+
+        loss = F.hinge_loss(pred, label)
+
+        print(loss.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [1.5]
+
+    """
+    assert norm in ["L1", "L2"], "norm must be L1 or L2"
+    # label is expected to hold -1/1 values; the hinge term max(0, 1 - pred * label) is computed with relu
+    loss = relu(1.0 - pred * label)
+    if norm == "L1":
+        return loss.sum(axis=1).mean()
+    else:
+        return (loss ** 2).sum(axis=1).mean()
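
As a sanity check of the formula above, the same reduction can be written directly in NumPy. This is a minimal sketch, not part of the commit; hinge_loss_ref is a hypothetical helper name:

    import numpy as np

    def hinge_loss_ref(pred, label, norm="L1"):
        # per-element hinge term: max(0, 1 - pred * label), with label in {-1, 1}
        loss = np.maximum(0.0, 1.0 - pred * label)
        if norm == "L2":
            loss = loss ** 2
        # sum over the class axis, then average over the batch,
        # matching the L1/L2 branches of hinge_loss above
        return loss.sum(axis=1).mean()

    pred = np.array([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
    label = np.array([[1, -1, -1], [-1, 1, 1]])
    print(hinge_loss_ref(pred, label))  # 1.5, matching the docstring output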

python_module/test/unit/functional/test_functional.py (+26, -0)

@@ -336,6 +336,32 @@ def test_binary_cross_entropy():
     opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)
 
 
+def test_hinge_loss():
+    np.random.seed(123)
+    # case with L1 norm
+    cases = []
+    for shape in [(2, 2), (2, 3)]:
+        data = np.random.uniform(size=shape).astype(np.float32)
+        label = 2 * np.random.randint(0, 2, size=shape).astype(np.int32) - 1
+        expect = np.clip(1 - data * label, 0, np.inf).sum(axis=1).mean()
+        cases.append({"input": [data, label], "output": tensor(expect)})
+
+    opr_test(cases, F.hinge_loss)
+
+    # cases with L2 norm
+    cases = []
+    for shape in [(2, 2), (2, 3)]:
+        data = np.random.uniform(size=shape).astype(np.float32)
+        label = 2 * np.random.randint(0, 2, size=shape).astype(np.int32) - 1
+        expect = (np.clip(1 - data * label, 0, np.inf) ** 2).sum(axis=1).mean()
+        cases.append({"input": [data, label], "output": tensor(expect)})
+
+    def hinge_loss_with_l2_norm(pred, label):
+        return F.hinge_loss(pred, label, "L2")
+
+    opr_test(cases, hinge_loss_with_l2_norm)
+
+
 @pytest.mark.skip
 def test_conv_bias():
     inp_scale = 0.01
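
One detail worth spelling out in the label construction above: np.random.randint's upper bound is exclusive, so the call must use (0, 2) to sample both 0 and 1, which 2 * x - 1 then maps to -1/1 labels; (0, 1) would only ever produce -1. A quick standalone check (plain NumPy, outside the test suite):

    import numpy as np

    np.random.seed(123)
    # randint(0, 2) draws from {0, 1}; 2 * x - 1 maps these to {-1, 1}
    label = 2 * np.random.randint(0, 2, size=(2, 3)).astype(np.int32) - 1
    print(set(label.ravel().tolist()) <= {-1, 1})  # True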

