From a6cfc4086f1b5cfaebbec3c4f33af66b95254aee Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 17:34:32 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9seq=5Flen=5Fto=5Fmask?= =?UTF-8?q?=E7=9A=84jittor=E5=AE=9E=E7=8E=B0=E5=8F=8A=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=E4=B8=AD=E7=9A=84=E4=B8=80=E5=A4=84=E4=BC=A0=E5=8F=82?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/utils/seq_len_to_mask.py | 2 +- tests/core/utils/test_seq_len_to_mask.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/utils/seq_len_to_mask.py b/fastNLP/core/utils/seq_len_to_mask.py index e244603c..710c0a2b 100644 --- a/fastNLP/core/utils/seq_len_to_mask.py +++ b/fastNLP/core/utils/seq_len_to_mask.py @@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None): if isinstance(seq_len, jittor.Var): assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}." batch_size = seq_len.shape[0] - broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1) + broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1) mask = broad_cast_seq_len < seq_len.unsqueeze(1) return mask except NameError as e: diff --git a/tests/core/utils/test_seq_len_to_mask.py b/tests/core/utils/test_seq_len_to_mask.py index 0a17bae6..64c84837 100644 --- a/tests/core/utils/test_seq_len_to_mask.py +++ b/tests/core/utils/test_seq_len_to_mask.py @@ -78,7 +78,7 @@ class TestSeqLenToMask: mask = seq_len_to_mask(seq_len) # 3. pad到指定长度 - seq_len = paddle.randint(1, 10, size=(10,)) + seq_len = paddle.randint(1, 10, shape=(10,)) mask = seq_len_to_mask(seq_len, 100) assert 100 == mask.shape[1]