From badd07f8f13784058bbd9d1da137bd2396b4168a Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 29 May 2022 07:04:24 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20ReproducibleSampler=20Repr?= =?UTF-8?q?oducibleBatchSampler=20=E7=9A=84=20num=5Fsamples=20=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/samplers/reproducible_batch_sampler.py | 20 ++++++++++++++++++-- fastNLP/core/samplers/reproducible_sampler.py | 5 ++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 874ad895..edb8a67f 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ @@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index a1dc318e..fe38a808 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler): :return: """ - return getattr(self.dataset, 'total_len', len(self.dataset)) + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len class SequentialSampler(RandomSampler): """