diff --git a/fastNLP/core/utils/seq_len_to_mask.py b/fastNLP/core/utils/seq_len_to_mask.py index e244603c..710c0a2b 100644 --- a/fastNLP/core/utils/seq_len_to_mask.py +++ b/fastNLP/core/utils/seq_len_to_mask.py @@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None): if isinstance(seq_len, jittor.Var): assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}." batch_size = seq_len.shape[0] - broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1) + broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1) mask = broad_cast_seq_len < seq_len.unsqueeze(1) return mask except NameError as e: diff --git a/tests/core/utils/test_seq_len_to_mask.py b/tests/core/utils/test_seq_len_to_mask.py index 0a17bae6..64c84837 100644 --- a/tests/core/utils/test_seq_len_to_mask.py +++ b/tests/core/utils/test_seq_len_to_mask.py @@ -78,7 +78,7 @@ class TestSeqLenToMask: mask = seq_len_to_mask(seq_len) # 3. pad到指定长度 - seq_len = paddle.randint(1, 10, size=(10,)) + seq_len = paddle.randint(1, 10, shape=(10,)) mask = seq_len_to_mask(seq_len, 100) assert 100 == mask.shape[1]