@@ -20,6 +20,7 @@ class ProgressCallback(HasMonitorCallback):
must_have_monitor=must_have_monitor)
self.best_monitor_epoch = -1
self.best_monitor_step = -1
self.best_results = None
def record_better_monitor(self, trainer):
self.best_monitor_step = trainer.global_forward_batches
@@ -29,6 +30,8 @@ class ProgressCallback(HasMonitorCallback):
if self.best_monitor_epoch != -1:
msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
if self.best_results is not None:
msg = msg + ' The evaluation result: \n' + str(self.best_results)
logger.info(msg)
@property
@@ -147,9 +150,11 @@ class RichCallback(ProgressCallback):
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
self.progress_bar.console.print_json(json.dumps(results))
results = json.dumps(results)
self.progress_bar.console.print_json(results)
else:
self.progress_bar.print(results)
self.best_results = results
def clear_tasks(self):
for key, taskid in self.task2id.items():
@@ -227,9 +232,9 @@ class RawTextCallback(ProgressCallback):
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
logger.info(json.dumps(results))
else:
logger.info(results)
results = json.dumps(results)
logger.info(results)
self.best_results = results
@property
def name(self):  # name of the progress bar
@@ -316,9 +321,9 @@ class TqdmCallback(ProgressCallback):
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
logger.info(json.dumps(results))
else:
logger.info(results)
results = json.dumps(results)
logger.info(results)
self.best_results = results
def clear_tasks(self):
for key, taskid in self.task2id.items():
@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException
from fastNLP.core.dataloaders import OverfitDataLoader
class Trainer(TrainerEventTrigger):
@@ -244,7 +245,20 @@ class Trainer(TrainerEventTrigger):
Note that this parameter only takes effect when the ``Trainer``'s built-in ``Evaluator`` is not None and there is a *callback* instance that needs this parameter but has not set it;
:param n_batches: how many ``batch`` iterations to run before training ends. When this value is not -1, the value of ``n_epochs`` is ignored.
:param n_batches: the total number of ``batch`` iterations to run before training ends. When this value is not -1, the value of ``n_epochs`` is ignored.
:param overfit_batches: use this parameter to enable the 'overfitting' feature; supported values are ``-1``, ``0``, or an integer greater than 0, indicating how many batches of data
are used for overfit training; ``0`` means nothing is done; ``-1`` means all the data is used for training;
.. note::
You can use this parameter as a quick check that your model is 'correct', i.e. that it can converge rapidly on a small amount of data, which indicates that the loss function, optimizer, etc.
are fine. When this parameter is used, we extract a fixed number of batches directly from ``train_dataloader`` and then use these data for training in every epoch;
.. warning::
When using this parameter you may also pass ``metrics`` for a simple evaluation; when both this parameter and ``metrics`` are given, we replace evaluate_dataloaders
directly with the training data used for overfitting, so you need to make sure your ``metrics`` can be applied to ``train_dataloader``;
:param marker: used to mark a ``Trainer`` instance so that, when the user calls the ``Trainer.on`` function, it is clear which specific ``Trainer`` instance the function belongs to; defaults to None;
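A minimal usage sketch for the new ``overfit_batches`` parameter (the model, optimizer, dataloader and the Accuracy metric below are placeholders, not objects from this diff):

# Hypothetical overfit run: cache 4 batches from train_dataloader and train on them
# in every epoch; because metrics is set, evaluation reuses those same 4 batches.
trainer = Trainer(
    model=model,
    driver="torch",
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    metrics={"acc": Accuracy()},
    overfit_batches=4,
    n_epochs=20,
)
trainer.run()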
@@ -372,6 +386,7 @@ class Trainer(TrainerEventTrigger):
monitor: Union[str, Callable] = None,
larger_better: bool = True,
n_batches: int = -1,
overfit_batches: int = 0,
marker: Optional[str] = None,
**kwargs
):
@@ -469,9 +484,6 @@ class Trainer(TrainerEventTrigger):
n_batches=n_batches
)
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")
if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")
@@ -495,33 +507,42 @@ class Trainer(TrainerEventTrigger):
else:
_dist_sampler = None
self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager._need_reproducible_sampler)
# overfit-related setup;
if overfit_batches != 0:
self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
self.overfit_batches = overfit_batches
self.evaluator = None
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto')  # if it is not
if not (isinstance(progress_bar, str) or progress_bar is None):  # should be a ProgressCallback; get its name.
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
if metrics is not None:
if overfit_batches != 0:
evaluate_dataloaders = self.dataloader
if evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto')  # if it is not
if not (isinstance(progress_bar, str) or progress_bar is None):  # should be a ProgressCallback; get its name.
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
else:
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")
if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
self.train_fn = train_fn
self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager._need_reproducible_sampler)
self.evaluate_batch_step_fn = evaluate_batch_step_fn
self.kwargs = kwargs
@@ -7,10 +7,13 @@ __all__ = [
'prepare_paddle_dataloader',
'prepare_torch_dataloader',
"prepare_dataloader"
"prepare_dataloader",
"OverfitDataLoader"
]
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
from .prepare_dataloader import prepare_dataloader
from .prepare_dataloader import prepare_dataloader
from .utils import OverfitDataLoader
@@ -1,4 +1,5 @@
from typing import Callable, Any, Union
import os
from typing import Callable, Any, Union, Sequence
from abc import ABC
import inspect
import ast
@@ -6,7 +7,8 @@ import ast
from ..log import logger
from ..utils.cache_results import get_func_calls, truncate_start_blanks
__all__ = [
"indice_collate_wrapper"
"indice_collate_wrapper",
"OverfitDataLoader"
]
@@ -111,6 +113,35 @@ class HasLenGetitemType(ABC):
return NotImplemented
class OverfitDataLoader:
"""
Implements a simple iterator that mimics a real dataloader by taking a subset of batches from the given dataloader, so that the Trainer can perform overfitting;
"""
def __init__(self, dataloader, overfit_batches: int):
self.dataloader = dataloader  # keep the real dataloader attached to this object so that operations aimed at the real dataloader can still be served;
self.batches = []
self.overfit_batches = int(overfit_batches)
if self.overfit_batches > len(dataloader):
logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
for idx, batch in enumerate(dataloader):
if idx < self.overfit_batches or self.overfit_batches <= -1:
self.batches.append(batch)
def __len__(self):
return len(self.batches)
def __iter__(self):
for batch in self.batches:
yield batch
def __getattr__(self, item):
return getattr(self.dataloader, item)
if __name__ == '__main__':
def demo(*args, **kwargs):
pass
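A small illustration of what the wrapper above does (``train_dataloader`` is a placeholder): the first ``overfit_batches`` batches are cached once in ``__init__`` and replayed on every iteration, while any other attribute lookup is delegated to the wrapped dataloader through ``__getattr__``.

overfit_dl = OverfitDataLoader(train_dataloader, overfit_batches=3)
assert len(overfit_dl) == 3          # only the cached batches are exposed
for epoch in range(5):
    for batch in overfit_dl:         # the same 3 batches in every epoch
        ...
print(overfit_dl.dataset)            # not defined on the wrapper, resolved on train_dataloader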
@@ -6,6 +6,7 @@ from dataclasses import dataclass
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.dataloaders import OverfitDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection, nullcontext
@@ -69,7 +70,7 @@ class JittorDriver(Driver):
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, (Dataset, JittorDataLoader)):
if not isinstance(dataloader, (Dataset, JittorDataLoader, OverfitDataLoader)):
raise TypeError(f"{Dataset} or {JittorDataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
@@ -14,6 +14,7 @@ from fastNLP.envs import (
FASTNLP_BACKEND_LAUNCH,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.log import logger
if _NEED_IMPORT_JITTOR:
@@ -63,6 +64,9 @@ def replace_batch_sampler(dataloader, batch_sampler):
"or report this bug to us.")
def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
batch_sampler = getattr(dataloader, "sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
raise RuntimeError("It should not be running here, please report a bug to us.")
if isinstance(dataloader, JittorDataLoader):
init_params = dict(inspect.signature(dataloader.__init__).parameters)
reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}
@@ -19,6 +19,7 @@ from fastNLP.envs import (
rank_zero_call,
)
from fastNLP.core.log import logger
from fastNLP.core.dataloaders import OverfitDataLoader
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
@@ -93,7 +94,7 @@ class PaddleDriver(Driver):
self.grad_scaler.update()
def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if dataloader.batch_size is None and dataloader.batch_sampler is None:
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"
@@ -15,6 +15,7 @@ from fastNLP.envs import (
FASTNLP_BACKEND_LAUNCH,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.utils import auto_param_call, paddle_to
from fastNLP.core.log import logger
@@ -129,7 +130,7 @@ def _build_fp16_env(dummy=False):
"NOTE: your device does NOT support faster training with fp16, "
"please switch to FP32 which is likely to be faster"
)
return auto_cast, GradScaler
return auto_cast, GradScaler
def find_free_ports(num):
"""
@@ -189,10 +190,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
non_default_params.add("dataset")
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,
})
if isinstance(dataloader, DataLoader):
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,
})
# POSITIONAL_OR_KEYWORD denotes ordinary parameters
# collect the ordinary parameters that appear in the init function, have no default value, and are not in reconstruct_args
@@ -210,9 +212,10 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
required_args = sorted(required_args)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
f"But they are not found among the attributes of `{dataloader_self_name}`, so fastNLP cannot determine their "
f"values when reinitializing `{dataloader_self_name}`. Please expose `{required_args}` as "
f"attributes of `{dataloader_self_name}`."
)
# this error targets the case where the passed dataloader is not a plain DataLoader but a customized DataLoader whose __init__ has no **kwargs;
@@ -224,10 +227,11 @@ def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler
missing_kwargs = sorted(missing_kwargs)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"The parameters {missing_kwargs} needed to reinitialize `{dataloader_self_name}` were not found."
)
# if there is no **kwargs, make sure only the required parameters are passed
if not isinstance(dataloader, DataLoader):
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}
return type(dataloader)(**reconstruct_args)
@@ -235,6 +239,9 @@ def replace_sampler(dataloader, new_sampler):
"""
Rebuild a ``BatchSampler`` with ``new_sampler`` and substitute it into the ``dataloader``.
"""
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
raise RuntimeError("It should not be running here, please report a bug to us.")
new_batch_sampler = deepcopy(dataloader.batch_sampler)
new_batch_sampler.sampler = new_sampler
return replace_batch_sampler(dataloader, new_batch_sampler)
@@ -140,9 +140,6 @@ if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler
from torch.utils.data import BatchSampler as TorchBatchSampler
__all__ = [
'TorchDDPDriver'
@@ -421,6 +418,7 @@ class TorchDDPDriver(TorchDriver):
os.environ['MASTER_ADDR'] = self.master_address
os.environ['MASTER_PORT'] = self.master_port
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = str(self.local_rank)
os.environ["WORLD_SIZE"] = f"{self.world_size}"
@@ -433,6 +431,7 @@ class TorchDDPDriver(TorchDriver):
for rank in range(1, len(self.parallel_device)):
env_copy = os.environ.copy()
env_copy["LOCAL_RANK"] = f"{rank}"
env_copy["RANK"] = f"{rank}"
# for multi-machine training the user must launch the processes themselves, so for the processes we open via open_subprocesses the FASTNLP_GLOBAL_RANK must be the LOCAL_RANK;
env_copy[FASTNLP_GLOBAL_RANK] = str(rank)
@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
from fastNLP.core.dataloaders import OverfitDataLoader
class TorchDriver(Driver):
@@ -92,7 +93,7 @@ class TorchDriver(Driver):
self.grad_scaler.update()
def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "
@@ -181,18 +181,16 @@ def replace_sampler(dataloader: "DataLoader", sampler):
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}
# 'multiprocessing_context' is a user-defined function;
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
if getattr(dataloader, 'multiprocessing_context', None) is not None:
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
# get the default signature of the dataloader's '__init__' function;
init_params = dict(inspect.signature(dataloader.__init__).parameters)
# The reason this is handled separately is that a user customizing their own dataloader may, for convenience, define only a few parameters explicitly and forward
# the rest via **kwargs. If such a user passes some extra, new parameters when instantiating their dataloader instance (first, this step is necessary, because it is the
# only way we can add the sampler; second, the user may indeed add new parameters through **kwargs), and they simply wrote "super().__init__(**kwargs)", then we can
# only look the parameters up in DataLoader;
# guard against the case where the user's DataLoader inherits from PyTorch's DataLoader but still passes parameters to the parent class via **kwargs
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
# written this way because a parameter of the same name in a user-customized Dataloader may have a different default value, so we cannot simply overwrite it with update;
if has_variadic_kwargs and isinstance(dataloader, DataLoader):
# guard against the user having written super().__init__(**kwargs)
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
if key not in init_params and key != 'self':
init_params[key] = value
@@ -204,7 +202,8 @@ def replace_sampler(dataloader: "DataLoader", sampler):
non_default_params.add("dataset")
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
if isinstance(dataloader, DataLoader):
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -218,35 +217,31 @@ def replace_sampler(dataloader: "DataLoader", sampler):
and p.name not in reconstruct_args
}
# this error targets __init__ parameters that were not attached to self under the same name;
# these parameters are not found among the attributes, so the dataloader cannot be reinitialized
if required_args:
required_args = sorted(required_args)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
f"But they are not found among the attributes of `{dataloader_self_name}`, so fastNLP cannot determine their "
f"values when reinitializing `{dataloader_self_name}`. Please expose `{required_args}` as "
f"attributes of `{dataloader_self_name}`."
)
# this error targets the case where the passed dataloader is not a plain DataLoader but a customized DataLoader whose __init__ has no **kwargs;
if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = reconstruct_args.keys() - init_params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The parameters {missing_kwargs} needed to reinitialize `{dataloader_self_name}` were not found."
)
# if there is no **kwargs, make sure only the required parameters are passed
if not isinstance(dataloader, DataLoader):
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}
return type(dataloader)(**reconstruct_args)
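The comments above describe user-defined DataLoader subclasses that forward most arguments through **kwargs; a hedged sketch of such a subclass (hypothetical, not part of this diff), which the rewritten logic can still reconstruct because the extra argument is exposed under the same attribute name and the remaining defaults can be read from DataLoader.__init__:

from torch.utils.data import DataLoader

class MyDataLoader(DataLoader):
    def __init__(self, dataset, my_arg=0, **kwargs):
        self.my_arg = my_arg                 # stored under the same name, so it can be recovered as an attribute
        super().__init__(dataset, **kwargs)  # defaults for everything else come from DataLoader.__init__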
@@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
params_keys.remove(k)
params = {k: getattr(dataloader, k) for k in params_keys}
params["batch_sampler"] = new_batch_sampler
if not isinstance(dataloader, DataLoader):
init_params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if not has_variadic_kwargs:
params = {key:value for key,value in params.items() if key in init_params}
return type(dataloader)(**params)
@@ -98,7 +98,7 @@ class Metric:
return _wrap_get_metric
def __setattr__(self, key, value):
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
if getattr(self, '_cannot_change_element', False):
if key in self.elements and isinstance(value, (float, int, bool)):
self.elements[key].fill_value(value)
return
@@ -109,6 +109,14 @@ class Metric:
raise RuntimeError("Please use register_element() function to add Element.")
object.__setattr__(self, key, value)
# this is only triggered when __getattribute__ does not find the attribute; it is kept only to suppress IDE warnings
def __getattr__(self, name: str) -> Element:
if 'elements' in self.__dict__:
elements = self.__dict__['elements']
if name in elements:
return elements[name]
raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))
def _wrap_update(self, update):
@functools.wraps(update)
def _wrap_update(*args, **kwargs):
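A standalone sketch of the attribute-delegation pattern added above (not fastNLP's actual Metric/Element classes, just the same idea): values kept in an ``elements`` dict become readable as attributes through ``__getattr__``, and ``__setattr__`` redirects plain writes into an existing element.

class ElementsDemo:
    def __init__(self):
        object.__setattr__(self, 'elements', {})
        object.__setattr__(self, '_cannot_change_element', True)

    def __setattr__(self, key, value):
        if getattr(self, '_cannot_change_element', False) and key in self.elements:
            self.elements[key] = value            # stand-in for Element.fill_value
            return
        object.__setattr__(self, key, value)

    def __getattr__(self, name):
        # reached only when normal attribute lookup fails
        elements = self.__dict__.get('elements', {})
        if name in elements:
            return elements[name]
        raise AttributeError(f"`{type(self).__name__}` object has no attribute `{name}`.")

demo = ElementsDemo()
demo.elements['total'] = 0
demo.total = 5        # routed into elements['total'] by __setattr__
print(demo.total)     # 5, resolved through __getattr__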
@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
assert trainer.driver.non_blocking is False
assert trainer.driver.wo_auto_param_call is True
if dist.is_initialized():
dist.destroy_process_group()
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,47 @@ def test_trainer_specific_params_2(
assert _ddp_kwargs.get("broadcast_buffers") is True
assert _ddp_kwargs.get("find_unused_parameters") is True
if dist.is_initialized():
dist.destroy_process_group()
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(
model_and_optimizers: TrainerParameters,
driver,
device,
overfit_batches,
num_train_batch_per_epoch
):
"""
Test that overfit_batches works correctly when the Trainer also has an Evaluator;
"""
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
overfit_batches=overfit_batches,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=2,
output_from_new_proc="all",
evaluate_every=-1,
torch_kwargs={
"non_blocking": False,
"set_grad_to_none": True
}
)
trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
if dist.is_initialized():
dist.destroy_process_group()
@@ -361,5 +361,35 @@ def test_torch_wo_auto_param_call(
dist.destroy_process_group()
# test accumulation_steps;
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_overfit_torch(
model_and_optimizers: TrainerParameters,
driver,
device,
overfit_batches,
num_train_batch_per_epoch
):
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
overfit_batches=overfit_batches,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
output_from_new_proc="all",
n_epochs=2,
)
trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
if dist.is_initialized():
dist.destroy_process_group()
@@ -1,3 +1,5 @@
import os
import pytest
from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver
@@ -40,9 +42,14 @@ def test_get_fleet(device):
"""
Test the initialization of fleet (multi-GPU) mode.
"""
flag = False
if "USER_CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["USER_CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
flag = True
model = PaddleNormalModel_Classification_1(20, 10)
driver = initialize_paddle_driver("paddle", device, model)
if flag:
del os.environ["USER_CUDA_VISIBLE_DEVICES"]
assert isinstance(driver, PaddleFleetDriver)