@@ -12,9 +12,9 @@ import numpy as np
 
 from ..core.tensor.array_method import _reduce
 from ..tensor import Tensor
-from .elemwise import abs, log
+from .elemwise import abs, equal, log, logaddexp, maximum
 from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
-from .tensor import where
+from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros
 
 __all__ = [
     "l1_loss",
@@ -22,6 +22,7 @@ __all__ = [
     "cross_entropy",
     "binary_cross_entropy",
     "hinge_loss",
+    "ctc_loss",
 ]
 
 
@@ -316,3 +317,164 @@ def hinge_loss(
         return loss.sum(axis=1)
     else:
         return (loss ** 2).sum(axis=1)
+
+
+def _gen_repeat_idx(inp: Tensor):
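+    # Build the batch index for every element of a concatenated label tensor,
+    # e.g. lengths [2, 3] -> [0, 0, 1, 1, 1] (a cumsum-based np.repeat).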
+    idx = cumsum(inp, axis=0)
+    ret = zeros(inp.sum(), dtype="int32")
+    ret[idx[:-1]] = 1
+    return cumsum(ret, axis=0)
+
+
+def _gen_tile_idx(inp: Tensor):
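+    # Build the within-sequence offset for every element of a concatenated
+    # label tensor, e.g. lengths [2, 3] -> [0, 1, 0, 1, 2].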
+    idx = cumsum(inp, axis=0)
+    ret = ones(inp.sum(), dtype="int32")
+    ret[idx[:-1]] = -(inp - 1)[:-1]
+    return cumsum(ret, axis=0) - 1
+
+
+def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
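+    # Convert labels to the extended CTC form by interleaving blanks:
+    # [a, b] -> [blank, a, blank, b, blank], so a length-L label becomes 2L + 1.
+    # A concatenated 1-D label tensor is first unpacked into a padded (N, L) matrix.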
+    N = label_lengths.shape[0]
+    if len(label.shape) == 1:
+        L = label_lengths.max()
+        unpack_label = zeros((N, L), dtype="int32") + blank
+        idx_0 = _gen_repeat_idx(label_lengths)
+        idx_1 = _gen_tile_idx(label_lengths)
+        unpack_label[idx_0, idx_1] = label
+        label = unpack_label
+
+    L = label.shape[1]
+    ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
+    ex_label[:, 1::2] = label
+    return ex_label
+
+
+def _safelog(x: Tensor) -> Tensor:
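+    # log clamped from below at the smallest positive normal value of the
+    # dtype, so zero probabilities give a large negative number instead of -inf.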
+    eps = np.finfo(x.dtype).tiny
+    return log(maximum(x, eps))
+
+
+def ctc_loss(
+    pred: Tensor,
+    pred_lengths: Tensor,
+    label: Tensor,
+    label_lengths: Tensor,
+    blank: int = 0,
+    reduction: str = "mean",
+) -> Tensor:
r"""The Connectionist Temporal Classification loss. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
pred: The probabilities of the output, shape is (T, N, C) , |
|
|
|
where T=input length, N=batch size, and C=number of classes (including blank). |
|
|
|
pred_lengths: number of time steps for each sequence in ``pred``, shape is (N, ) |
|
|
|
label: groundtruth labels, containing the indices of groundtruth |
|
|
|
symbols for each sequence at each output time step, and the blank |
|
|
|
symbol should not be included. shape is (N, S) or (sum(label_lengths)). |
|
|
|
label_lengths: number of time steps for each sequence in the groundtruth, shape is (N, ) |
|
|
|
blank: the blank symbol number, default 0 |
|
|
|
reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean' |
|
|
|
|
|
|
|
Returns: |
|
|
|
loss value. |
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
from megengine import tensor |
|
|
|
import megengine.functional as F |
|
|
|
|
|
|
|
pred = tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]]) |
|
|
|
pred_length = tensor([2,2]) |
|
|
|
label = tensor([1,1]) |
|
|
|
label_lengths = tensor([1,1]) |
|
|
|
loss = F.nn.ctc_loss(pred, pred_length, label, label_lengths) |
|
|
|
print(loss.numpy()) |
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
0.1504417 |
|
|
|
|
|
|
|
""" |
+    T, N, C = pred.shape
+
+    assert (
+        pred_lengths.size == N
+    ), "pred_lengths must be equal to batch_size {}, but got {}".format(
+        N, pred_lengths.size
+    )
+    assert (
+        label_lengths.size == N
+    ), "label_lengths must be equal to batch_size {}, but got {}".format(
+        N, label_lengths.size
+    )
+    assert (
+        blank >= 0 and blank < C
+    ), "blank must be in label range [0, {}), but got {}".format(C, blank)
+    assert (
+        pred_lengths.min() > 0 and pred_lengths.max() <= T
+    ), "pred_lengths must be in range ({}, {}], but got min {}, max {}".format(
+        0, T, pred_lengths.min(), pred_lengths.max()
+    )
+
+    if label.ndim == 1:  # concatenated label
+        assert label_lengths.min() > 0, "label lengths must be positive"
+        assert (
+            label.size == label_lengths.sum()
+        ), "label size must be equal to sum(label_lengths)"
+    else:
+        N, S = label.shape
+        assert (
+            label_lengths.min() > 0 and label_lengths.max() <= S
+        ), "label_lengths must be in range ({}, {}], but got min {}, max {}".format(
+            0, S, label_lengths.min(), label_lengths.max()
+        )
+
+    label = _expand_label(label, label_lengths, blank)
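+    # label_mask[n, j] is True where extended position j + 2 holds a different
+    # symbol than position j, i.e. where the CTC "skip over a blank" transition
+    # between two distinct symbols is allowed.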
+    label_mask = label[:, 2:] != label[:, :-2]
+    L = label.shape[1]
+
+    pred = pred.transpose(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+    batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
+    batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
+
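+    # Gather per-symbol probabilities: match_pred[n, j, t] is the probability
+    # of the j-th extended-label symbol of sequence n at time step t.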
+    match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
+        N, L, -1
+    )  # (N, T, C) -> (N, L, T)
+
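+    # Forward variables at t = 0: a valid path can only start at the leading
+    # blank or the first label symbol; all other positions start at
+    # _safelog(0), a stand-in for log(0).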
+    log_alpha = zeros((N, L), dtype="float32")
+    log_alpha[:, :2] = match_pred[:, :2, 0]
+    log_alpha = _safelog(log_alpha)
+
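+    # A valid path must end on the trailing blank or the last label symbol;
+    # equal() selects the sequences whose last time step is t = 0.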
+    ret = -logaddexp(
+        log_alpha[batch_idx, label_lengths * 2],
+        log_alpha[batch_idx, label_lengths * 2 - 1],
+    ) * equal(pred_lengths - 1, 0)
+
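+    # Forward recurrence: position j can be reached by staying (j), advancing
+    # one step (j - 1), or skipping a blank (j - 2) where label_mask allows;
+    # each step then absorbs the current probabilities and collects the loss
+    # of the sequences that end at time step t.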
+    for t in range(1, T):
+        la2 = log_alpha[:, :-2]
+        log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
+        log_alpha[:, 2:] = (
+            log_alpha[:, 2:] * (1 - label_mask)
+            + logaddexp(log_alpha[:, 2:], la2) * label_mask
+        )
+        log_alpha += _safelog(match_pred[:, :, t])
+
+        ret_t = -logaddexp(
+            log_alpha[batch_idx, label_lengths * 2],
+            log_alpha[batch_idx, label_lengths * 2 - 1],
+        )
+        ret += ret_t * equal(pred_lengths - 1, t)
+
+    if reduction == "mean":
+        return (ret / label_lengths).mean()
+    elif reduction == "sum":
+        return ret.sum()
+    elif reduction == "none":
+        return ret
+    else:
+        raise ValueError("{} is not a valid value for reduction".format(reduction))
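
As a sanity check (outside the patch), the docstring value can be reproduced with plain NumPy by enumerating the three length-2 alignments that collapse to the single-symbol label; the variable names here are illustrative:

import numpy as np

# (T, N, C) probabilities from the docstring example; blank = 0, label = [1].
pred = np.array([[[0.0614, 0.9386], [0.8812, 0.1188]],
                 [[0.699, 0.301], [0.2572, 0.7428]]])
losses = []
for n in range(pred.shape[1]):
    p = pred[:, n, :]  # (T, C) slice for sequence n
    # Alignments "11", "-1" and "1-" (with "-" the blank) all collapse to "1".
    prob = p[0, 1] * p[1, 1] + p[0, 0] * p[1, 1] + p[0, 1] * p[1, 0]
    losses.append(-np.log(prob))
print(np.mean(losses))  # ~0.1504417, matching the docstring testoutput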