From dd1c5ca03579b435ef690559874fc840ee58e48f Mon Sep 17 00:00:00 2001
From: yhcc
Date: Mon, 13 Jun 2022 23:15:06 +0800
Subject: [PATCH] Fix the bug caused by setting the global seed
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/drivers/torch_driver/torch_driver.py | 24 +++++-----------------
 fastNLP/core/drivers/torch_driver/utils.py        |  2 +-
 .../core/samplers/reproducible_batch_sampler.py   |  9 ++++----
 fastNLP/core/samplers/reproducible_sampler.py     |  5 ++---
 fastNLP/core/samplers/unrepeated_sampler.py       |  5 ++---
 tests/core/drivers/torch_driver/test_ddp.py       | 23 +++++++++++++--------
 6 files changed, 28 insertions(+), 40 deletions(-)

diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 96529073..fb01b6c3 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -202,26 +202,12 @@ class TorchDriver(Driver):
         num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
-            # num_consumed_samples needs special handling: because the DataLoader prefetches, reading
-            # num_consumed_samples directly from the sampler would over-count what has actually been consumed.
-            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
-            if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):  # for a sampler, the batch_size has to be taken into account.
-                    if dataloader_args.batch_size is not None:
-                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    else:  # batch_size may be None, in which case some precision is simply lost
-                        logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                 "it may cause missing some samples when reload.")
-                        num_consumed_batches = sampler_states['num_consumed_samples']
-                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            if dataloader_args.batch_size is not None:
+                sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                         * num_consumed_batches
             else:
-                if dataloader_args.batch_size is not None:
-                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
-                                                             * num_consumed_batches
-                else:
-                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                             "it may cause missing some samples when reload.")
+                logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's "
+                                         "`num_consumed_samples`, it may cause missing some samples when reload.")
 
             states['sampler_states'] = sampler_states
         else:
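The rewritten branch above computes the saved sample count directly from quantities the driver already knows, instead of reading the prefetch-affected counter out of the sampler. A minimal sketch of that arithmetic, with a hypothetical helper name that is not part of fastNLP:

# With 2 replicas, a batch size of 8 and 5 batches already handed to the training loop,
# the checkpoint records 2 * 8 * 5 = 80 consumed samples, regardless of how far the
# DataLoader has prefetched ahead of the trainer.
def consumed_samples(num_replicas: int, batch_size: int, num_consumed_batches: int) -> int:
    return num_replicas * batch_size * num_consumed_batches

assert consumed_samples(2, 8, 5) == 80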
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index 300bf196..f5a76a9e 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -283,7 +283,7 @@ def optimizer_state_to_device(state, device):
 
 
 def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler,
+    if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
                                                                TorchSequentialSampler}):
         mode = 'training' if controller == 'Trainer' else 'evaluation'
         substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index f522f997..50276ba1 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -13,7 +13,6 @@ from itertools import chain
 import numpy as np
 
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.utils import get_global_seed
 from fastNLP.core.log import logger
 from .utils import create_array
 from abc import abstractmethod
@@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for internal use by fastNLP
     """
     def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
-                 drop_last: bool = False, seed: int = None, **kwargs):
+                 drop_last: bool = False, seed: int = 0, **kwargs):
         super().__init__()
 
         self.dataset = dataset
@@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
 
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including those consumed on other cards in the multi-card case
@@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for internal use by fastNLP
     """
     def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
-                 shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs):
+                 shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
         super().__init__()
         if isinstance(dataset, DataSet) and isinstance(length, str):
             length = dataset.get_field(length).content
@@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.num_batch_per_bucket = num_batch_per_bucket
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
 
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including those consumed on other cards in the multi-card case
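RandomBatchSampler and BucketedBatchSampler above, like the samplers in the next two files, now default `seed` to a fixed `0` and coerce it with `int(seed)` instead of falling back to `get_global_seed()`. A minimal stand-in sketch of that constructor behaviour (a hypothetical class, not the fastNLP samplers themselves):

# No implicit lookup of a process-wide global seed any more: the default is the same
# constant on every rank, and an explicitly passed seed is still honoured and coerced.
class _SeededSamplerSketch:
    def __init__(self, seed: int = 0):
        self.seed = int(seed)

assert _SeededSamplerSketch().seed == 0
assert _SeededSamplerSketch(seed=42).seed == 42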
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index dc396851..e1a06fa1 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -12,7 +12,6 @@ import numpy as np
 
 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.utils import get_global_seed
 
 
 class ReproducibleSampler:
@@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler):
     :param seed: random seed.
     :param kwargs: not intended for users; used internally by fastNLP
     """
-    def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs):
+    def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
         super(RandomSampler, self).__init__()
         self.dataset = dataset
 
         self.shuffle = shuffle
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including those consumed on other cards in the multi-card case
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
index 22207274..e959a4d0 100644
--- a/fastNLP/core/samplers/unrepeated_sampler.py
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
@@ -7,7 +7,6 @@ __all__ = [
 
 from typing import List, Union
 
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.utils import get_global_seed
 
 import numpy as np
@@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
     :param seed: the random seed to use
     :param kwargs: reserved for internal use by fastNLP
     """
-    def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs):
+    def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
         self.dataset = dataset
         self.shuffle = shuffle
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
 
         # parameters related to multi-card setups
         self.num_replicas = kwargs.get('num_replicas', 1)
diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py
index 74f44c04..46abd84c 100644
--- a/tests/core/drivers/torch_driver/test_ddp.py
+++ b/tests/core/drivers/torch_driver/test_ddp.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 from pathlib import Path
 
@@ -185,7 +187,7 @@ class TestSetDistReproDataloader:
         cls.device = [0, 1]
 
     def setup_method(self):
-        self.dataset = TorchNormalDataset(40)
+        self.dataset = TorchNormalDataset(100)
 
     """
     Cases where the `dist` argument is a concrete ReproducibleSampler or ReproducibleBatchSampler
@@ -571,7 +573,7 @@ class TestSaveLoad:
     """
 
     def setup_method(self):
-        self.dataset = TorchNormalXYDataset(20)
+        self.dataset = TorchNormalXYDataset(100)
 
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -641,7 +643,7 @@ class TestSaveLoad:
                 rank=driver1.global_rank,
                 pad=True
             )
-            num_consumed_batches = 2
+            num_consumed_batches = 4
 
             already_seen_x_set = set()
             already_seen_y_set = set()
@@ -686,7 +688,8 @@ class TestSaveLoad:
             assert not (replaced_loader is dataloader)
             assert replaced_loader.batch_sampler is dataloader.batch_sampler
             assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
-            assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
+            if os.environ['FASTNLP_GLOBAL_RANK'] == '0':
+                assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
             assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas
 
             # 3. check that fp16 was loaded
@@ -753,7 +756,7 @@ class TestSaveLoad:
                 rank=driver1.global_rank,
                 pad=True
             )
-            num_consumed_batches = 2
+            num_consumed_batches = 4
 
             already_seen_x_set = set()
             already_seen_y_set = set()
@@ -792,11 +795,13 @@ class TestSaveLoad:
             # 2. check that the sampler was loaded and replaced correctly
             assert not (replaced_loader is dataloader)
             assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
-            assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
-            assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
+            if os.environ['FASTNLP_GLOBAL_RANK'] == '0':
+                assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
+                assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
+                assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
+                assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
             assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas
-            assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
-            assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
+
             # 3. check that fp16 was loaded
             if fp16:
                 assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
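The updated tests only compare seed, epoch, length and shuffle against the reloaded state in the process whose FASTNLP_GLOBAL_RANK is '0'. A hedged sketch of that guard pattern (the helper name is illustrative and not part of the test suite):

import os

# Run an assertion only on global rank 0, mirroring the
# `if os.environ['FASTNLP_GLOBAL_RANK'] == '0':` blocks added in test_ddp.py.
def assert_on_rank0(condition: bool, message: str = "") -> None:
    if os.environ.get('FASTNLP_GLOBAL_RANK', '0') == '0':
        assert condition, message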