|
|
@@ -176,30 +176,19 @@ def cross_entropy(
         "target_ndim={}".format(n0, n1)
     )

-    num_classes = pred.shape[axis]
-    no_label_smooth = (
-        label_smooth is None or type(label_smooth) in (int, float) and label_smooth == 0
-    )
+    ls = label_smooth

+    if with_logits:
+        logZ = logsumexp(pred, axis).mean()
+        primary_term = indexing_one_hot(pred, label, axis).mean()
+    else:
+        logZ = 0
+        primary_term = log(indexing_one_hot(pred, label, axis)).mean()
+    if ls is None or type(ls) in (int, float) and ls == 0:
+        return logZ - primary_term
     if not with_logits:
-        if no_label_smooth:
-            return -log(indexing_one_hot(pred, label, axis)).mean()
         pred = log(pred)
-        return (
-            label_smooth * pred.mean()
-            - (1 - label_smooth) * indexing_one_hot(pred, label, axis).mean()
-        )
-
-    # Denominator of the softmax
-    down = logsumexp(pred, axis=axis, keepdims=True)
-
-    up = indexing_one_hot(pred, label, axis)
-
-    if not no_label_smooth:
-        factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
-
-    return (down - up).mean()
+    return logZ - ls * pred.mean() - (1 - ls) * primary_term


 def binary_cross_entropy(
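
Note (not part of the patch): the new body collapses the separate with_logits / label_smooth branches into the single closed form logZ - ls * pred.mean() - (1 - ls) * primary_term. A minimal NumPy sketch of why this matches the old "down - up" computation on the with_logits path; scipy.special.logsumexp and the names old_loss / new_loss stand in for the MegEngine ops and are illustrative only.

# Illustrative equivalence check for the with_logits=True, label-smoothed case.
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
N, C, ls = 8, 5, 0.1                  # batch size, num classes, label_smooth
pred = rng.normal(size=(N, C))        # raw logits
label = rng.integers(0, C, size=N)
picked = pred[np.arange(N), label]    # plays the role of indexing_one_hot

# Old formulation: per-sample log-partition minus the smoothed target logit.
down = logsumexp(pred, axis=1)
up = picked * (1 - ls) + pred.sum(axis=1) * (ls / C)
old_loss = (down - up).mean()

# New formulation from this hunk: ls * pred.mean() equals the mean of
# (ls / C) * pred.sum(axis=1), so the two expressions agree.
logZ = logsumexp(pred, axis=1).mean()
primary_term = picked.mean()
new_loss = logZ - ls * pred.mean() - (1 - ls) * primary_term

assert np.isclose(old_loss, new_loss)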
|
|
|