Browse Source

修改 ReproducibleSampler ReproducibleBatchSampler 的 num_samples 获取方式

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
badd07f8f1
2 changed files with 22 additions and 3 deletions
  1. +18
    -2
      fastNLP/core/samplers/reproducible_batch_sampler.py
  2. +4
    -1
      fastNLP/core/samplers/reproducible_sampler.py

+ 18
- 2
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler):

@property
def num_samples(self):
return getattr(self.dataset, 'total_len', len(self.dataset))
"""
返回样本的总数

:return:
"""
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len

def __len__(self)->int:
"""
@@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler):

@property
def num_samples(self):
return getattr(self.dataset, 'total_len', len(self.dataset))
"""
返回样本的总数

:return:
"""
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len

def __len__(self)->int:
"""


+ 4
- 1
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler):

:return:
"""
return getattr(self.dataset, 'total_len', len(self.dataset))
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len

class SequentialSampler(RandomSampler):
"""


Loading…
Cancel
Save