diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 4fa9af21..2db8efb3 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -61,6 +61,7 @@ def _elwise(*args, mode):
         _ElwMod.H_SWISH,
         _ElwMod.SIGMOID,
         _ElwMod.SIN,
+        _ElwMod.LOG_SUM_EXP,
     ) and (
         amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
     ):
diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py
index e54fecb4..ba58556d 100644
--- a/imperative/python/megengine/functional/elemwise.py
+++ b/imperative/python/megengine/functional/elemwise.py
@@ -48,6 +48,7 @@ __all__ = [
     "logical_not",
     "logical_or",
     "logical_xor",
+    "logaddexp",
     "maximum",
     "minimum",
     "mod",
@@ -406,6 +407,12 @@ def logical_xor(x, y):
     return _elwise(x, y, mode=Elemwise.Mode.XOR)
 
 
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+    r"""Element-wise numerically stable ``log(exp(x) + exp(y))``.
+    """
+    return _elwise(x, y, mode=Elemwise.Mode.LOG_SUM_EXP)
+
+
 # comparison functions
 
 
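A minimal usage sketch of the new ``logaddexp`` (assuming the patch above is applied; for comparison it only uses existing ``megengine.functional`` ops such as ``exp`` and ``log``), illustrating the numerical-stability claim in the docstring: the naive ``log(exp(x) + exp(y))`` overflows for large inputs, while ``logaddexp`` stays finite.

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.array([0.5, 1000.0], dtype="float32"))
    y = tensor(np.array([0.3, 1000.0], dtype="float32"))

    naive = F.log(F.exp(x) + F.exp(y))  # exp(1000.0) overflows, so the second entry is inf
    stable = F.logaddexp(x, y)          # second entry is about 1000.6931, i.e. 1000 + log(2)
    print(naive.numpy(), stable.numpy())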
diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index 655c5d0b..40f37410 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -12,9 +12,9 @@ import numpy as np
 
 from ..core.tensor.array_method import _reduce
 from ..tensor import Tensor
-from .elemwise import abs, log
+from .elemwise import abs, equal, log, logaddexp, maximum
 from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
-from .tensor import where
+from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros
 
 __all__ = [
     "l1_loss",
@@ -22,6 +22,7 @@ __all__ = [
     "cross_entropy",
     "binary_cross_entropy",
     "hinge_loss",
+    "ctc_loss",
 ]
 
 
@@ -316,3 +317,164 @@ def hinge_loss(
         return loss.sum(axis=1)
     else:
         return (loss ** 2).sum(axis=1)
+
+
+def _gen_repeat_idx(inp: Tensor):
+    idx = cumsum(inp, axis=0)
+    ret = zeros(inp.sum(), dtype="int32")
+    ret[idx[:-1]] = 1
+    return cumsum(ret, axis=0)
+
+
+def _gen_tile_idx(inp: Tensor):
+    idx = cumsum(inp, axis=0)
+    ret = ones(inp.sum(), dtype="int32")
+    ret[idx[:-1]] = -(inp - 1)[:-1]
+    return cumsum(ret, axis=0) - 1
+
+
+def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
+    N = label_lengths.shape[0]
+    if len(label.shape) == 1:
+        L = label_lengths.max()
+        unpack_label = zeros((N, L), dtype="int32") + blank
+        idx_0 = _gen_repeat_idx(label_lengths)
+        idx_1 = _gen_tile_idx(label_lengths)
+        unpack_label[idx_0, idx_1] = label
+        label = unpack_label
+
+    L = label.shape[1]
+    ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
+    ex_label[:, 1::2] = label
+    return ex_label
+
+
+def _safelog(x: Tensor) -> Tensor:
+    eps = np.finfo(x.dtype).tiny
+    return log(maximum(x, eps))
+
+
+def ctc_loss(
+    pred: Tensor,
+    pred_lengths: Tensor,
+    label: Tensor,
+    label_lengths: Tensor,
+    blank: int = 0,
+    reduction: str = "mean",
+) -> Tensor:
+    r"""The Connectionist Temporal Classification loss.
+
+
+    Args:
+        pred: the probabilities of the output, shape is (T, N, C),
+            where T=input length, N=batch size, and C=number of classes (including blank).
+        pred_lengths: number of time steps for each sequence in ``pred``, shape is (N, )
+        label: groundtruth labels, containing the indices of groundtruth
+            symbols for each sequence at each output time step, and the blank
+            symbol should not be included.
+            Shape is (N, S) or (sum(label_lengths), ).
+        label_lengths: number of time steps for each sequence in the groundtruth, shape is (N, )
+        blank: the blank symbol number, default 0
+        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
+
+    Returns:
+        loss value.
+
+    Examples:
+
+        .. testcode::
+
+            from megengine import tensor
+            import megengine.functional as F
+
+            pred = tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
+            pred_lengths = tensor([2,2])
+            label = tensor([1,1])
+            label_lengths = tensor([1,1])
+            loss = F.nn.ctc_loss(pred, pred_lengths, label, label_lengths)
+            print(loss.numpy())
+
+        Outputs:
+
+        .. testoutput::
+
+            0.1504417
+
+    """
+    T, N, C = pred.shape
+
+    assert (
+        pred_lengths.size == N
+    ), "pred_lengths must be equal to batch_size {}, but got {}".format(
+        N, pred_lengths.size
+    )
+    assert (
+        label_lengths.size == N
+    ), "label_lengths must be equal to batch_size {}, but got {}".format(
+        N, label_lengths.size
+    )
+    assert (
+        blank >= 0 and blank < C
+    ), "blank must be in label range [0, {}), but got {}".format(C, blank)
+    assert (
+        pred_lengths.min() > 0 and pred_lengths.max() <= T
+    ), "pred_lengths must be in range ({}, {}], but got min {}, max {}".format(
+        0, T, pred_lengths.min(), pred_lengths.max()
+    )
+
+    if label.ndim == 1:  # concatenated label
+        assert label_lengths.min() > 0, "label lengths must be positive"
+        assert (
+            label.size == label_lengths.sum()
+        ), "label size must be equal to sum(label_lengths)"
+    else:
+        N, S = label.shape
+        assert (
+            label_lengths.min() > 0 and label_lengths.max() <= S
+        ), "label_lengths must be in range ({}, {}], but got min {}, max {}".format(
+            0, S, label_lengths.min(), label_lengths.max()
+        )
+
+    label = _expand_label(label, label_lengths, blank)
+    label_mask = label[:, 2:] != label[:, :-2]
+    L = label.shape[1]
+
+    pred = pred.transpose(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+    batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
+    batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
+
+    match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
+        N, L, -1
+    )  # (N, T, C) -> (N, L, T)
+
+    log_alpha = zeros((N, L), dtype="float32")
+    log_alpha[:, :2] = match_pred[:, :2, 0]
+    log_alpha = _safelog(log_alpha)
+
+    ret = -logaddexp(
+        log_alpha[batch_idx, label_lengths * 2],
+        log_alpha[batch_idx, label_lengths * 2 - 1],
+    ) * equal(pred_lengths - 1, 0)
+    for t in range(1, T):
+        la2 = log_alpha[:, :-2]
+        log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
+        log_alpha[:, 2:] = (
+            log_alpha[:, 2:] * (1 - label_mask)
+            + logaddexp(log_alpha[:, 2:], la2) * label_mask
+        )
+        log_alpha += _safelog(match_pred[:, :, t])
+
+        ret_t = -logaddexp(
+            log_alpha[batch_idx, label_lengths * 2],
+            log_alpha[batch_idx, label_lengths * 2 - 1],
+        )
+        ret += ret_t * equal(pred_lengths - 1, t)
+
+    if reduction == "mean":
+        return (ret / label_lengths).mean()
+    elif reduction == "sum":
+        return ret.sum()
+    elif reduction == "none":
+        return ret
+    else:
+        raise ValueError("{} is not a valid value for reduction".format(reduction))
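For intuition about the forward recursion in ``ctc_loss`` above: ``_expand_label`` interleaves the blank symbol with the target symbols, and ``label_mask`` permits the "skip" transition only into a non-blank position whose symbol differs from the one two slots earlier. A small NumPy sketch of the same layout, using a hypothetical label ``[1, 2, 2]`` with ``blank=0``:

    import numpy as np

    label = np.array([1, 2, 2], dtype=np.int32)
    blank = 0

    # Interleave blanks: odd positions hold the real symbols, as in _expand_label.
    ex_label = np.zeros(2 * len(label) + 1, dtype=np.int32) + blank
    ex_label[1::2] = label
    print(ex_label)  # [0 1 0 2 0 2 0]

    # The skip transition i-2 -> i is allowed only where the two symbols differ,
    # which is what `label_mask = label[:, 2:] != label[:, :-2]` encodes; the
    # repeated "2" therefore has to pass through the blank between them.
    skip_allowed = ex_label[2:] != ex_label[:-2]
    print(skip_allowed)  # [False  True False False False]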
diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py
index e4cae31c..782c000f 100644
--- a/imperative/python/test/unit/functional/test_elemwise.py
+++ b/imperative/python/test/unit/functional/test_elemwise.py
@@ -170,6 +170,16 @@ def test_logical_oprs():
     np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())
 
 
+def test_logaddexp():
+    x = np.random.randn(2, 100)
+    y = np.random.randn(2, 100)
+    xx = tensor(x)
+    yy = tensor(y)
+    out_np = np.log(np.exp(x) + np.exp(y))
+    out_mge = F.logaddexp(xx, yy)
+    np.testing.assert_almost_equal(out_np, out_mge.numpy(), decimal=6)
+
+
 def test_qadd():
     inp_scale = 0.5
     outp_scale = 0.2
diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py
index 9a725552..d46f40b6 100644
--- a/imperative/python/test/unit/functional/test_loss.py
+++ b/imperative/python/test/unit/functional/test_loss.py
@@ -79,3 +79,128 @@ def test_cross_entropy_reduction():
 
     with pytest.raises(ValueError):
         F.nn.cross_entropy(logits, label, reduction="max")
+
+
+def ctc_nll_naive_npy(
+    pred,
+    pred_lengths,
+    label,
+    label_lengths,
+    blank=0,
+    reduction="mean",
+    time_major=False,
+):
+    """Naive :func:`ctc_nll` using numpy arrays. Used for testing and for
+    helping users understand how CTC works. Only ``LABEL_COMPACT`` mode is
+    supported."""
+
+    pred = np.asarray(pred, dtype=np.float32)
+    pred_lengths = np.asarray(pred_lengths, dtype=np.int8)
+    label = np.asarray(label, dtype=np.int32)
+    label_lengths = np.asarray(label_lengths, dtype=np.int32)
+
+    if time_major:
+        pred = np.transpose(pred, (1, 0, 2))
+    # pred in (N, T, P) format
+
+    batch_size, time_len, nr_class = pred.shape
+    assert pred_lengths.shape == (batch_size,) and pred_lengths.max() <= pred.shape[1]
+    assert label_lengths.shape == (batch_size,)
+    assert label.shape == (label_lengths.sum(),) and label.max() < nr_class
+
+    ret = np.empty((batch_size,), dtype=np.float32)
+    label_start = 0
+    for i in range(batch_size):
+        label_end = label_start + label_lengths[i]
+        ret[i] = _ctc_npy_single_seq(
+            pred[i][: pred_lengths[i]], label[label_start:label_end], blank
+        )
+        label_start = label_end
+
+    if reduction == "mean":
+        return (ret / label_lengths).mean()
+    elif reduction == "sum":
+        return ret.sum()
+    elif reduction == "none":
+        return ret
+    else:
+        raise ValueError("{} is not a valid value for reduction".format(reduction))
+
+
+def _ctc_npy_single_seq(pred, label, blank):
+    def safelog(x):
+        eps = np.finfo(x.dtype).tiny
+        return np.log(np.maximum(x, eps))
+
+    def log_sum_exp(x, y):
+        x, y = np.maximum(x, y), np.minimum(x, y)
+        return x + np.log1p(np.exp(y - x))
+
+    assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
+    len_pred, alphabet_size = pred.shape
+    (len_label,) = label.shape
+
+    len_ex_label = len_label * 2 + 1
+    ex_label = (np.zeros(len_ex_label)).astype(np.int32) + blank
+    ex_label[1::2] = label
+
+    prob = np.zeros(len_ex_label, dtype=np.float32)
+    prob[0] = pred[0][ex_label[0]]
+    prob[1] = pred[0][ex_label[1]]
+    prob = safelog(prob)  # compute on log scale
+
+    ex_label_pmask = ex_label[2:] != ex_label[:-2]
+    for t in range(1, len_pred):
+        # enter loop: prob[i] = log(p(pred[:t+1], label[:i+1]))
+        new_prob = prob.copy()
+        new_prob[1:] = log_sum_exp(new_prob[1:], prob[:-1])
+        new_prob[2:] = (
+            new_prob[2:] * (1 - ex_label_pmask)
+            + log_sum_exp(new_prob[2:], prob[:-2]) * ex_label_pmask
+        )
+        new_prob += safelog(pred[t, ex_label])
+        prob = new_prob
+
+    return -log_sum_exp(prob[-1], prob[-2])
+
+
+def test_ctc_loss():
+    def test_func(T, C, N):
+        input = np.random.randn(T, N, C)
+        input = F.softmax(tensor(input), axis=-1).numpy()
+        input_lengths = np.ones(N, dtype=np.int32) * T
+        target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
+        target = np.random.randint(
+            low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
+        )
+
+        input_mge = tensor(input)
+        input_lengths_mge = tensor(input_lengths)
+
+        target_mge = tensor(target)
+        target_lengths_mge = tensor(target_lengths)
+
+        blank = np.random.randint(C)
+        for method in ["mean", "sum", "none"]:
+            np_out = ctc_nll_naive_npy(
+                input,
+                input_lengths,
+                target,
+                target_lengths,
+                blank=blank,
+                reduction=method,
+                time_major=True,
+            )
+            mge_out = F.nn.ctc_loss(
+                input_mge,
+                input_lengths_mge,
+                target_mge,
+                target_lengths_mge,
+                blank=blank,
+                reduction=method,
+            )
+            np.testing.assert_allclose(mge_out.numpy(), np_out, rtol=2e-6)
+
+    cases = [[1, 2, 1], [100, 50, 200], [100, 5, 1]]
+    for case in cases:
+        test_func(*case)
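As a sanity check on the ``ctc_loss`` docstring example, the expected value can also be reproduced by brute-force path enumeration in plain NumPy (a sketch: with a single-symbol label ``[1]``, ``blank=0`` and two time steps, the alignments that collapse to the label are blank->1, 1->blank and 1->1):

    import numpy as np

    pred = np.array(
        [[[0.0614, 0.9386], [0.8812, 0.1188]], [[0.699, 0.301], [0.2572, 0.7428]]]
    )  # (T=2, N=2, C=2), same values as the docstring example

    losses = []
    for n in range(pred.shape[1]):
        p = pred[:, n, :]  # per-sequence probabilities, shape (T, C)
        # Sum the probabilities of the three valid alignments.
        prob = p[0, 0] * p[1, 1] + p[0, 1] * p[1, 0] + p[0, 1] * p[1, 1]
        losses.append(-np.log(prob))

    print(np.mean(losses))  # ~0.150441, matching the 0.1504417 printed in the docstring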