
refactor(mge): use logsumexp in cross_entropy

GitOrigin-RevId: 4a14aabb94
release-1.1
Megvii Engine Team 4 years ago
parent commit d31a4fff73
2 changed files with 3 additions and 5 deletions
  1. +2 -4  imperative/python/megengine/functional/loss.py
  2. +1 -1  imperative/python/megengine/functional/nn.py

+2 -4  imperative/python/megengine/functional/loss.py

@@ -11,7 +11,7 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot, logsigmoid, logsoftmax
+from .nn import indexing_one_hot, logsigmoid, logsumexp
 from .tensor import where


 __all__ = [
@@ -191,9 +191,7 @@ def cross_entropy(
 )


     # Denominator of the softmax
-    offset = pred.detach().max(axis=axis, keepdims=True)
-    pred = pred - offset
-    down = log(exp(pred).sum(axis=axis, keepdims=True))
+    down = logsumexp(pred, axis=axis, keepdims=True)


     up = indexing_one_hot(pred, label, axis)
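
The removed lines guarded against overflow by hand: subtract the row maximum before exponentiating, which `logsumexp` now does internally. A minimal NumPy sketch (illustration only, not MegEngine code) of the trick the old inline code implemented and the new call encapsulates:

```python
import numpy as np

def naive_denominator(pred, axis=1):
    # Direct log(sum(exp(...))) overflows once pred holds large logits.
    return np.log(np.exp(pred).sum(axis=axis, keepdims=True))

def stable_denominator(pred, axis=1):
    # logsumexp trick: subtract the max before exponentiating, add it back after.
    offset = pred.max(axis=axis, keepdims=True)
    return offset + np.log(np.exp(pred - offset).sum(axis=axis, keepdims=True))

logits = np.array([[1000.0, 1001.0]])
print(naive_denominator(logits))   # [[inf]]       -- exp(1000) overflows
print(stable_denominator(logits))  # [[1001.3133]] -- finite, same math
```

Delegating to `logsumexp` keeps the stability trick in one place instead of duplicating it inside `cross_entropy`.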




+1 -1  imperative/python/megengine/functional/nn.py

@@ -546,7 +546,7 @@ def logsumexp(
 [-0.5481 4.4519]


""" """
-    max_value = max(inp, axis, keepdims=True)
+    max_value = max(inp.detach(), axis, keepdims=True)
     if keepdims:
         return max_value + log(sum(exp(inp - max_value), axis, keepdims))
     else:
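
Detaching `max_value` is safe because the max offset cancels out of the exact gradient: for f(x) = m + log(Σ exp(x - m)), the ∂m terms sum to zero and ∂f/∂x_i = softmax(x)_i regardless of m. The detach simply stops autograd from tracing a redundant path through the (tie-ambiguous) max. A hypothetical check of this cancellation, written with PyTorch only because it ships autograd; MegEngine is assumed to behave analogously:

```python
import torch

def lse(x, detach_max):
    # Stable logsumexp with an optionally detached max offset.
    m = x.max()
    if detach_max:
        m = m.detach()
    return m + torch.log(torch.exp(x - m).sum())

x1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
lse(x1, detach_max=True).backward()

x2 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
lse(x2, detach_max=False).backward()

print(x1.grad)  # softmax([1, 2, 3])
print(x2.grad)  # identical: the offset's gradient contribution cancels
```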

