
refactor(mge): use logsumexp in cross_entropy

commit d31a4fff73 (branch: release-1.1)
Megvii Engine Team (4 years ago)
GitOrigin-RevId: 4a14aabb94
2 changed files with 3 additions and 5 deletions:
  1. imperative/python/megengine/functional/loss.py  (+2, -4)
  2. imperative/python/megengine/functional/nn.py    (+1, -1)

imperative/python/megengine/functional/loss.py  (+2, -4)

@@ -11,7 +11,7 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot, logsigmoid, logsoftmax
+from .nn import indexing_one_hot, logsigmoid, logsumexp
 from .tensor import where

 __all__ = [
@@ -191,9 +191,7 @@ def cross_entropy(
     )

     # Denominator of the softmax
-    offset = pred.detach().max(axis=axis, keepdims=True)
-    pred = pred - offset
-    down = log(exp(pred).sum(axis=axis, keepdims=True))
+    down = logsumexp(pred, axis=axis, keepdims=True)

     up = indexing_one_hot(pred, label, axis)



imperative/python/megengine/functional/nn.py  (+1, -1)

@@ -546,7 +546,7 @@ def logsumexp(
     [-0.5481 4.4519]

     """
-    max_value = max(inp, axis, keepdims=True)
+    max_value = max(inp.detach(), axis, keepdims=True)
     if keepdims:
         return max_value + log(sum(exp(inp - max_value), axis, keepdims))
     else:

