Browse Source

ci(imperative): fix imperative test_ctc_loss test

GitOrigin-RevId: d48dccc30c
master
Megvii Engine Team 2 years ago
parent
commit
ffb626243b
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      imperative/python/test/unit/functional/test_loss.py

+ 2
- 1
imperative/python/test/unit/functional/test_loss.py View File

@@ -134,7 +134,6 @@ def _ctc_npy_single_seq(pred, label, blank):
x, y = np.maximum(x, y), np.minimum(x, y)
return x + np.log1p(np.exp(y - x))

assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
len_pred, alphabet_size = pred.shape
(len_label,) = label.shape

@@ -166,6 +165,8 @@ def test_ctc_loss():
def test_func(T, C, N):
input = np.random.randn(T, N, C)
input = F.softmax(Tensor(input), axis=-1).numpy()
# replace nan to 0.2
input = np.nan_to_num(input, copy=True, nan=0.2)
input_lengths = np.ones(N, dtype=np.int32) * T
target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
target = np.random.randint(


Loading…
Cancel
Save