@@ -11,7 +11,7 @@ import numpy as np
 from ..core.tensor.utils import make_shape_tuple
 from ..tensor import Tensor
 from .elemwise import abs, equal, exp, log, maximum, pow, relu
-from .nn import indexing_one_hot, logsigmoid, logsoftmax
+from .nn import indexing_one_hot, logsigmoid, logsumexp
 from .tensor import where

 __all__ = [
@@ -191,9 +191,7 @@ def cross_entropy(
     )

     # Denominator of the softmax
-    offset = pred.detach().max(axis=axis, keepdims=True)
-    pred = pred - offset
-    down = log(exp(pred).sum(axis=axis, keepdims=True))
+    down = logsumexp(pred, axis=axis, keepdims=True)

     up = indexing_one_hot(pred, label, axis)

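Note: the second hunk swaps the hand-rolled max-shift computation of the softmax denominator for a single logsumexp call, which is the usual numerically stable form of the same quantity. Dropping the pred = pred - offset line should be safe because the loss only depends on the difference between the log-sum-exp term and the selected logit, so a per-row shift cancels. A minimal NumPy sketch of the identity involved (illustrative only, not MegEngine code; the helper names here are made up):

    import numpy as np

    def naive_lse(x, axis):
        # Direct form: exp() overflows for large logits.
        return np.log(np.exp(x).sum(axis=axis, keepdims=True))

    def stable_lse(x, axis):
        # Same value via log(sum(exp(x))) == m + log(sum(exp(x - m))), with m = max(x),
        # so exp() is only ever applied to non-positive numbers.
        m = x.max(axis=axis, keepdims=True)
        return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

    logits = np.array([[1000.0, 1001.0, 1002.0]])
    print(naive_lse(logits, axis=1))   # [[inf]]  (overflow)
    print(stable_lse(logits, axis=1))  # [[1002.40760596]]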