@@ -202,26 +202,12 @@ class TorchDriver(Driver): | |||
num_consumed_batches = states.pop('num_consumed_batches') | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
sampler_states = sampler.state_dict() | |||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
# 会造成多余实际消耗的问题。因为 | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
if dataloader_args.batch_size is not None: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
if dataloader_args.batch_size is not None: | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " | |||
"`num_consumed_samples`, it may cause missing some samples when reload.") | |||
states['sampler_states'] = sampler_states | |||
else: | |||
@@ -283,7 +283,7 @@ def optimizer_state_to_device(state, device): | |||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||
if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler, | |||
if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler, | |||
TorchSequentialSampler}): | |||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||
@@ -13,7 +13,6 @@ from itertools import chain | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.envs.utils import get_global_seed | |||
from fastNLP.core.log import logger | |||
from .utils import create_array | |||
from abc import abstractmethod | |||
@@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | |||
drop_last: bool = False, seed: int = None, **kwargs): | |||
drop_last: bool = False, seed: int = 0, **kwargs): | |||
super().__init__() | |||
self.dataset = dataset | |||
@@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
self.batch_size = batch_size | |||
self.shuffle = shuffle | |||
self.drop_last = drop_last | |||
self.seed = get_global_seed() if seed is None else seed | |||
self.seed = int(seed) | |||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||
@@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | |||
shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs): | |||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||
super().__init__() | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
@@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
self.num_batch_per_bucket = num_batch_per_bucket | |||
self.shuffle = shuffle | |||
self.drop_last = drop_last | |||
self.seed = get_global_seed() if seed is None else seed | |||
self.seed = int(seed) | |||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||
@@ -12,7 +12,6 @@ import numpy as np | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.envs.utils import get_global_seed | |||
class ReproducibleSampler: | |||
@@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler): | |||
:param seed: 随机数种子。 | |||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | |||
""" | |||
def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs): | |||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||
super(RandomSampler, self).__init__() | |||
self.dataset = dataset | |||
self.shuffle = shuffle | |||
self.seed = get_global_seed() if seed is None else seed | |||
self.seed = int(seed) | |||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||
@@ -7,7 +7,6 @@ __all__ = [ | |||
from typing import List, Union | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.envs.utils import get_global_seed | |||
import numpy as np | |||
@@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
:param seed: 设置的随机数种子 | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs): | |||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||
self.dataset = dataset | |||
self.shuffle = shuffle | |||
self.seed = get_global_seed() if seed is None else seed | |||
self.seed = int(seed) | |||
# 多卡的相关的参数 | |||
self.num_replicas = kwargs.get('num_replicas', 1) | |||
@@ -1,3 +1,5 @@ | |||
import os | |||
import pytest | |||
from pathlib import Path | |||
@@ -185,7 +187,7 @@ class TestSetDistReproDataloader: | |||
cls.device = [0, 1] | |||
def setup_method(self): | |||
self.dataset = TorchNormalDataset(40) | |||
self.dataset = TorchNormalDataset(100) | |||
""" | |||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||
@@ -571,7 +573,7 @@ class TestSaveLoad: | |||
""" | |||
def setup_method(self): | |||
self.dataset = TorchNormalXYDataset(20) | |||
self.dataset = TorchNormalXYDataset(100) | |||
@magic_argv_env_context | |||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||
@@ -641,7 +643,7 @@ class TestSaveLoad: | |||
rank=driver1.global_rank, | |||
pad=True | |||
) | |||
num_consumed_batches = 2 | |||
num_consumed_batches = 4 | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
@@ -686,7 +688,8 @@ class TestSaveLoad: | |||
assert not (replaced_loader is dataloader) | |||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||
if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | |||
# 3. 检查 fp16 是否被加载 | |||
@@ -753,7 +756,7 @@ class TestSaveLoad: | |||
rank=driver1.global_rank, | |||
pad=True | |||
) | |||
num_consumed_batches = 2 | |||
num_consumed_batches = 4 | |||
already_seen_x_set = set() | |||
already_seen_y_set = set() | |||
@@ -792,11 +795,13 @@ class TestSaveLoad: | |||
# 2. 检查 sampler 是否被正确地加载和替换 | |||
assert not (replaced_loader is dataloader) | |||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||
if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | |||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||
# 3. 检查 fp16 是否被加载 | |||
if fp16: | |||
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||