@@ -202,26 +202,12 @@ class TorchDriver(Driver): | |||||
num_consumed_batches = states.pop('num_consumed_batches') | num_consumed_batches = states.pop('num_consumed_batches') | ||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。因为 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
if dataloader_args.batch_size is not None: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | else: | ||||
if dataloader_args.batch_size is not None: | |||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
"it may cause missing some samples when reload.") | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " | |||||
"`num_consumed_samples`, it may cause missing some samples when reload.") | |||||
states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
else: | else: | ||||
@@ -283,7 +283,7 @@ def optimizer_state_to_device(state, device): | |||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | def _check_dataloader_args_for_distributed(args, controller='Trainer'): | ||||
if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler, | |||||
if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler, | |||||
TorchSequentialSampler}): | TorchSequentialSampler}): | ||||
mode = 'training' if controller == 'Trainer' else 'evaluation' | mode = 'training' if controller == 'Trainer' else 'evaluation' | ||||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | ||||
@@ -13,7 +13,6 @@ from itertools import chain | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.utils import get_global_seed | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .utils import create_array | from .utils import create_array | ||||
from abc import abstractmethod | from abc import abstractmethod | ||||
@@ -171,7 +170,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | ||||
drop_last: bool = False, seed: int = None, **kwargs): | |||||
drop_last: bool = False, seed: int = 0, **kwargs): | |||||
super().__init__() | super().__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
@@ -179,7 +178,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | ||||
@@ -398,7 +397,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs): | |||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||||
super().__init__() | super().__init__() | ||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
@@ -423,7 +422,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
self.num_batch_per_bucket = num_batch_per_bucket | self.num_batch_per_bucket = num_batch_per_bucket | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | ||||
@@ -12,7 +12,6 @@ import numpy as np | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.utils import get_global_seed | |||||
class ReproducibleSampler: | class ReproducibleSampler: | ||||
@@ -66,11 +65,11 @@ class RandomSampler(ReproducibleSampler): | |||||
:param seed: 随机数种子。 | :param seed: 随机数种子。 | ||||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | :param kwargs: 用户不需要使用,fastNLP 内部使用 | ||||
""" | """ | ||||
def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||||
super(RandomSampler, self).__init__() | super(RandomSampler, self).__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | ||||
@@ -7,7 +7,6 @@ __all__ = [ | |||||
from typing import List, Union | from typing import List, Union | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.utils import get_global_seed | |||||
import numpy as np | import numpy as np | ||||
@@ -28,10 +27,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:param seed: 设置的随机数种子 | :param seed: 设置的随机数种子 | ||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs): | |||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = get_global_seed() if seed is None else seed | |||||
self.seed = int(seed) | |||||
# 多卡的相关的参数 | # 多卡的相关的参数 | ||||
self.num_replicas = kwargs.get('num_replicas', 1) | self.num_replicas = kwargs.get('num_replicas', 1) | ||||
@@ -1,3 +1,5 @@ | |||||
import os | |||||
import pytest | import pytest | ||||
from pathlib import Path | from pathlib import Path | ||||
@@ -185,7 +187,7 @@ class TestSetDistReproDataloader: | |||||
cls.device = [0, 1] | cls.device = [0, 1] | ||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchNormalDataset(40) | |||||
self.dataset = TorchNormalDataset(100) | |||||
""" | """ | ||||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | ||||
@@ -571,7 +573,7 @@ class TestSaveLoad: | |||||
""" | """ | ||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchNormalXYDataset(20) | |||||
self.dataset = TorchNormalXYDataset(100) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@@ -641,7 +643,7 @@ class TestSaveLoad: | |||||
rank=driver1.global_rank, | rank=driver1.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
num_consumed_batches = 2 | |||||
num_consumed_batches = 4 | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
@@ -686,7 +688,8 @@ class TestSaveLoad: | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | assert replaced_loader.batch_sampler is dataloader.batch_sampler | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||||
if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | ||||
# 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
@@ -753,7 +756,7 @@ class TestSaveLoad: | |||||
rank=driver1.global_rank, | rank=driver1.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
num_consumed_batches = 2 | |||||
num_consumed_batches = 4 | |||||
already_seen_x_set = set() | already_seen_x_set = set() | ||||
already_seen_y_set = set() | already_seen_y_set = set() | ||||
@@ -792,11 +795,13 @@ class TestSaveLoad: | |||||
# 2. 检查 sampler 是否被正确地加载和替换 | # 2. 检查 sampler 是否被正确地加载和替换 | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | ||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
if fp16: | if fp16: | ||||
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | ||||