
Fix a bug caused by setting the global seed

tags/v1.0.0alpha
yhcc 3 years ago
commit dd1c5ca035
6 changed files with 28 additions and 40 deletions
  1. +5  -19  fastNLP/core/drivers/torch_driver/torch_driver.py
  2. +1  -1   fastNLP/core/drivers/torch_driver/utils.py
  3. +4  -5   fastNLP/core/samplers/reproducible_batch_sampler.py
  4. +2  -3   fastNLP/core/samplers/reproducible_sampler.py
  5. +2  -3   fastNLP/core/samplers/unrepeated_sampler.py
  6. +14 -9   tests/core/drivers/torch_driver/test_ddp.py

+5 -19   fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -202,26 +202,12 @@ class TorchDriver(Driver):
         num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
-            # num_consumed_samples needs special handling: because the DataLoader prefetches,
-            # reading num_consumed_samples straight from the sampler would over-count what was actually consumed.
-            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
-            if num_consumed_samples_array is not None:
-                if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size has to be taken into account.
-                    if dataloader_args.batch_size is not None:
-                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    else:  # batch_size may be None, in which case some precision is lost
-                        logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                                 "it may cause missing some samples when reload.")
-                        num_consumed_batches = sampler_states['num_consumed_samples']
-                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
-                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            if dataloader_args.batch_size is not None:
+                sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                         * num_consumed_batches
             else:
-                if dataloader_args.batch_size is not None:
-                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
-                                                             * num_consumed_batches
-                else:
-                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
-                                             "it may cause missing some samples when reload.")
+                logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's "
+                                         "`num_consumed_samples`, it may cause missing some samples when reload.")
 
             states['sampler_states'] = sampler_states
         else:

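The surviving branch reduces the checkpoint bookkeeping to a single formula: each of the `num_replicas` processes consumes `batch_size` samples per step, so the global consumption after `num_consumed_batches` steps is their product. A minimal sketch of that arithmetic (the helper name is hypothetical, not part of fastNLP):

    # Hypothetical helper illustrating the arithmetic of the new branch.
    def global_samples_consumed(num_replicas: int, batch_size: int, num_consumed_batches: int) -> int:
        # every rank draws batch_size samples per batch, across num_replicas ranks
        return num_replicas * batch_size * num_consumed_batches

    # e.g. 2 GPUs, batch_size=4, 4 batches consumed -> resume from sample 32
    assert global_samples_consumed(2, 4, 4) == 32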

+1 -1   fastNLP/core/drivers/torch_driver/utils.py

@@ -283,7 +283,7 @@ def optimizer_state_to_device(state, device):
 
 
 def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler,
+    if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
                                                               TorchSequentialSampler}):
         mode = 'training' if controller == 'Trainer' else 'evaluation'
         substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'

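The one-character change flips the predicate's behavior: with `and`, a custom batch_sampler was accepted silently whenever the plain sampler was one of the torch defaults; with `or`, either non-default component now reaches the warning path. A toy evaluation with plain booleans standing in for the two type checks:

    # Booleans standing in for the type checks in _check_dataloader_args_for_distributed.
    batch_sampler_is_custom = True   # type(args.batch_sampler) is not TorchBatchSampler
    sampler_is_default = True        # type(args.sampler) in {TorchRandomSampler, TorchSequentialSampler}

    old_triggers = batch_sampler_is_custom and not sampler_is_default  # False: the bug, nothing is reported
    new_triggers = batch_sampler_is_custom or not sampler_is_default   # True: the custom batch_sampler is caught
    assert old_triggers is False and new_triggers is True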

+4 -5   fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -13,7 +13,6 @@ from itertools import chain
 import numpy as np
 
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.utils import get_global_seed
 from fastNLP.core.log import logger
 from .utils import create_array
 from abc import abstractmethod
@@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for internal use by fastNLP
     """
     def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
-                 drop_last: bool = False, seed: int = None, **kwargs):
+                 drop_last: bool = False, seed: int = 0, **kwargs):
         super().__init__()
 
         self.dataset = dataset
@@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
 
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including output on other cards in the multi-GPU case
@@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for internal use by fastNLP
     """
     def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
-                 shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs):
+                 shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
         super().__init__()
         if isinstance(dataset, DataSet) and isinstance(length, str):
             length = dataset.get_field(length).content
@@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.num_batch_per_bucket = num_batch_per_bucket
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
 
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including output on other cards in the multi-GPU case


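With the `get_global_seed()` fallback removed, omitting `seed` now always means `seed=0`, so batch order no longer depends on process-wide seeding state. A hedged usage sketch (it assumes a plain sized sequence is an acceptable `dataset` and that `RandomBatchSampler` is exported from `fastNLP.core.samplers`):

    from fastNLP.core.samplers import RandomBatchSampler

    dataset = list(range(100))  # assumption: anything with a __len__ suffices for illustration
    a = RandomBatchSampler(dataset, batch_size=8, shuffle=True)          # implicit seed=0
    b = RandomBatchSampler(dataset, batch_size=8, shuffle=True, seed=0)  # explicit seed=0
    # a and b now shuffle identically; pass a different seed to get a different order.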

+2 -3   fastNLP/core/samplers/reproducible_sampler.py

@@ -12,7 +12,6 @@ import numpy as np
 
 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.utils import get_global_seed
 
 
 class ReproducibleSampler:
@@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler):
     :param seed: random seed.
     :param kwargs: not intended for users; used internally by fastNLP
     """
-    def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs):
+    def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
         super(RandomSampler, self).__init__()
         self.dataset = dataset
         self.shuffle = shuffle
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including output on other cards in the multi-GPU case


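The same applies to `RandomSampler`: two instances built without an explicit seed now agree with each other and with any checkpoint they produce. A sketch of the round trip (assuming the usual `load_state_dict` counterpart to the `state_dict` call used by the driver above):

    from fastNLP.core.samplers import RandomSampler

    dataset = list(range(100))
    sampler = RandomSampler(dataset, shuffle=True)   # seed defaults to 0; no global state consulted
    states = sampler.state_dict()                    # carries the seed, so a reload reshuffles identically
    restored = RandomSampler(dataset, shuffle=True)
    restored.load_state_dict(states)                 # assumption: standard state_dict/load_state_dict pair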

+2 -3   fastNLP/core/samplers/unrepeated_sampler.py

@@ -7,7 +7,6 @@ __all__ = [
 
 from typing import List, Union
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.utils import get_global_seed
 
 import numpy as np
 
@@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
     :param seed: the random seed to use
     :param kwargs: reserved for internal use by fastNLP
     """
-    def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs):
+    def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
        self.dataset = dataset
        self.shuffle = shuffle
-        self.seed = get_global_seed() if seed is None else seed
+        self.seed = int(seed)
 
        # multi-GPU related parameters
        self.num_replicas = kwargs.get('num_replicas', 1)

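For the unrepeated variant the fixed default matters even more: the ranks can only partition one permutation without overlap if they all derive it from the same seed, which `seed=0` now guarantees even when no seed is passed. A hedged sketch (`num_replicas` goes through `**kwargs` as the diff above shows; passing `rank` the same way is an assumption):

    from fastNLP.core.samplers import UnrepeatedRandomSampler

    dataset = list(range(100))
    rank0 = UnrepeatedRandomSampler(dataset, shuffle=True, num_replicas=2, rank=0)
    rank1 = UnrepeatedRandomSampler(dataset, shuffle=True, num_replicas=2, rank=1)
    # assumption: iterating yields dataset indices; the two halves cover the dataset exactly once
    assert set(rank0) | set(rank1) == set(range(100))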

+14 -9   tests/core/drivers/torch_driver/test_ddp.py

@@ -1,3 +1,5 @@
+import os
+
 import pytest
 from pathlib import Path
 
@@ -185,7 +187,7 @@ class TestSetDistReproDataloader:
         cls.device = [0, 1]
 
     def setup_method(self):
-        self.dataset = TorchNormalDataset(40)
+        self.dataset = TorchNormalDataset(100)
 
     """
     Cases where the `dist` argument passed in is a concrete ReproducibleSampler or ReproducibleBatchSampler
@@ -571,7 +573,7 @@ class TestSaveLoad:
     """
 
     def setup_method(self):
-        self.dataset = TorchNormalXYDataset(20)
+        self.dataset = TorchNormalXYDataset(100)
 
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -641,7 +643,7 @@ class TestSaveLoad:
             rank=driver1.global_rank,
             pad=True
         )
-        num_consumed_batches = 2
+        num_consumed_batches = 4
 
         already_seen_x_set = set()
         already_seen_y_set = set()
@@ -686,7 +688,8 @@ class TestSaveLoad:
         assert not (replaced_loader is dataloader)
         assert replaced_loader.batch_sampler is dataloader.batch_sampler
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
-        assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
+        if os.environ['FASTNLP_GLOBAL_RANK'] == '0':
+            assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
         assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas
 
         # 3. check that fp16 was loaded
@@ -753,7 +756,7 @@ class TestSaveLoad:
             rank=driver1.global_rank,
             pad=True
         )
-        num_consumed_batches = 2
+        num_consumed_batches = 4
 
         already_seen_x_set = set()
         already_seen_y_set = set()
@@ -792,11 +795,13 @@ class TestSaveLoad:
         # 2. check that the sampler was correctly loaded and replaced
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
-        assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
-        assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
-        assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
-        assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
+        if os.environ['FASTNLP_GLOBAL_RANK'] == '0':
+            assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
+            assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
+            assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
+            assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
+        assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas
 
         # 3. check that fp16 was loaded
         if fp16:
             assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

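The new rank guards exist because the `sampler_states` being compared come from the checkpoint written on rank 0, so only that rank's live sampler is guaranteed to match the saved values field for field, while `num_consumed_samples` is recomputed and should hold on every rank. The guarded shape in isolation, with hypothetical stand-in values:

    import os

    saved_states = {"seed": 0, "epoch": 3}  # hypothetical values standing in for a loaded checkpoint
    live_seed, live_epoch = 0, 3

    # value-for-value checks run only on the process marked as global rank 0
    if os.environ.get('FASTNLP_GLOBAL_RANK', '0') == '0':
        assert live_seed == saved_states["seed"]
        assert live_epoch == saved_states["epoch"]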
