diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index 8ed5958a..40da78d8 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -11,7 +11,7 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot, logsigmoid, logsoftmax
+from .nn import indexing_one_hot, logsigmoid, logsumexp
 from .tensor import where
 
 __all__ = [
@@ -191,9 +191,7 @@ def cross_entropy(
     )
 
     # Denominator of the softmax
-    offset = pred.detach().max(axis=axis, keepdims=True)
-    pred = pred - offset
-    down = log(exp(pred).sum(axis=axis, keepdims=True))
+    down = logsumexp(pred, axis=axis, keepdims=True)
 
     up = indexing_one_hot(pred, label, axis)
 
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 3fc2e223..6ba8ee0e 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -546,7 +546,7 @@ def logsumexp(
         [-0.5481 4.4519]
 
     """
-    max_value = max(inp, axis, keepdims=True)
+    max_value = max(inp.detach(), axis, keepdims=True)
     if keepdims:
         return max_value + log(sum(exp(inp - max_value), axis, keepdims))
     else:
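
A minimal NumPy sketch (illustration only, not part of the patch) of the identity that both the removed hand-rolled code in cross_entropy and the logsumexp helper rely on, logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x), which keeps exp from overflowing for large logits:

    import numpy as np

    def naive_logsumexp(x, axis):
        # Direct evaluation: exp() overflows to inf for large inputs.
        return np.log(np.exp(x).sum(axis=axis))

    def stable_logsumexp(x, axis):
        # Subtract the per-row max first, then add it back outside the log.
        m = x.max(axis=axis, keepdims=True)
        return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

    x = np.array([[1000.0, 1001.0], [-5.0, 5.0]])
    print(naive_logsumexp(x, axis=1))   # [inf, 5.0000454] -- overflow in the first row
    print(stable_logsumexp(x, axis=1))  # [1001.3132617, 5.0000454]

Detaching max_value is safe here because subtracting a constant and adding it back does not change the value, and the gradient of logsumexp with respect to inp is softmax(inp) whether or not the max term is differentiated through; stopping the gradient at the max just avoids a redundant backward path through the reduction.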