diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py index 2a1950de..d047b3ed 100644 --- a/imperative/python/test/unit/functional/test_loss.py +++ b/imperative/python/test/unit/functional/test_loss.py @@ -134,7 +134,6 @@ def _ctc_npy_single_seq(pred, label, blank): x, y = np.maximum(x, y), np.minimum(x, y) return x + np.log1p(np.exp(y - x)) - assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3 len_pred, alphabet_size = pred.shape (len_label,) = label.shape @@ -166,6 +165,8 @@ def test_ctc_loss(): def test_func(T, C, N): input = np.random.randn(T, N, C) input = F.softmax(Tensor(input), axis=-1).numpy() + # replace nan to 0.2 + input = np.nan_to_num(input, copy=True, nan=0.2) input_lengths = np.ones(N, dtype=np.int32) * T target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32) target = np.random.randint(