
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
a39d3011e8
10 changed files with 203 additions and 60 deletions
  1. fastNLP/core/callbacks/progress_callback.py  +12 -7
  2. fastNLP/core/controllers/trainer.py  +43 -22
  3. fastNLP/core/dataloaders/__init__.py  +5 -2
  4. fastNLP/core/dataloaders/utils.py  +33 -2
  5. fastNLP/core/drivers/torch_driver/ddp.py  +0 -3
  6. fastNLP/core/drivers/torch_driver/torch_driver.py  +2 -1
  7. fastNLP/core/drivers/torch_driver/utils.py  +24 -22
  8. fastNLP/core/metrics/metric.py  +9 -1
  9. tests/core/controllers/test_trainer_w_evaluator_torch.py  +45 -0
  10. tests/core/controllers/test_trainer_wo_evaluator_torch.py  +30 -0

fastNLP/core/callbacks/progress_callback.py  +12 -7

@@ -20,6 +20,7 @@ class ProgressCallback(HasMonitorCallback):
                          must_have_monitor=must_have_monitor)
         self.best_monitor_epoch = -1
         self.best_monitor_step = -1
+        self.best_results = None

     def record_better_monitor(self, trainer):
         self.best_monitor_step = trainer.global_forward_batches
@@ -29,6 +30,8 @@ class ProgressCallback(HasMonitorCallback):
         if self.best_monitor_epoch != -1:
             msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
                   f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            if self.best_results is not None:
+                msg = msg + ' The evaluation result: \n' + str(self.best_results)
             logger.info(msg)

     @property
@@ -147,9 +150,11 @@ class RichCallback(ProgressCallback):
             results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                        not key.startswith('_')}
             if self.format_json:
-                self.progress_bar.console.print_json(json.dumps(results))
+                results = json.dumps(results)
+                self.progress_bar.console.print_json(results)
             else:
                 self.progress_bar.print(results)
+            self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
@@ -227,9 +232,9 @@ class RawTextCallback(ProgressCallback):
             results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                        not key.startswith('_')}
             if self.format_json:
-                logger.info(json.dumps(results))
-            else:
-                logger.info(results)
+                results = json.dumps(results)
+            logger.info(results)
+            self.best_results = results

     @property
     def name(self):  # name of the progress bar
@@ -316,9 +321,9 @@ class TqdmCallback(ProgressCallback):
             results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                        not key.startswith('_')}
             if self.format_json:
-                logger.info(json.dumps(results))
-            else:
-                logger.info(results)
+                results = json.dumps(results)
+            logger.info(results)
+            self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():


fastNLP/core/controllers/trainer.py  +43 -22

@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.utils.exceptions import EarlyStopException
+from fastNLP.core.dataloaders import OverfitDataLoader


 class Trainer(TrainerEventTrigger):
@@ -244,7 +245,20 @@ class Trainer(TrainerEventTrigger):

        Note that this parameter only takes effect when the ``Evaluator`` built into the ``Trainer`` is not None and there is a *callback* instance that needs it but has not set it;

-    :param n_batches: Training stops after iterating over this many batches. When this value is not -1, the value of ``n_epochs`` is ignored.
+    :param n_batches: Training stops after iterating over this many batches in total. When this value is not -1, the value of ``n_epochs`` is ignored.
+    :param overfit_batches: Use this parameter to enable the 'overfitting' feature; supported values are ``-1``, ``0`` or an integer greater than 0, indicating how many
+        batches of data are used for the overfitting run; 0 means nothing is done, and -1 means all of the data is used;
+
+        .. note::
+
+            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data, which indicates
+            that the loss function, optimizer, etc. are fine. When this parameter is used, a fixed number of batches is taken directly from ``train_dataloader`` and the
+            same batches are used for training in every epoch;
+
+        .. warning::
+
+            When using this parameter you can still pass ``metrics`` to run a simple evaluation; when this parameter and ``metrics`` appear together, evaluate_dataloaders
+            is replaced directly with the training data used for overfitting; you therefore need to make sure your ``metrics`` can be used on ``train_dataloader``;
+
     :param marker: Used to mark a ``Trainer`` instance so that when the user calls ``Trainer.on``, it is clear which specific ``Trainer`` instance the function belongs to; defaults to None;

@@ -372,6 +386,7 @@ class Trainer(TrainerEventTrigger):
             monitor: Union[str, Callable] = None,
             larger_better: bool = True,
             n_batches: int = -1,
+            overfit_batches: int = 0,
             marker: Optional[str] = None,
             **kwargs
     ):
@@ -469,9 +484,6 @@ class Trainer(TrainerEventTrigger):
             n_batches=n_batches
         )

-        if metrics is None and evaluate_dataloaders is not None:
-            raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")
-
         if metrics is not None and evaluate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

@@ -495,33 +507,42 @@
         else:
             _dist_sampler = None

+        self.dataloader = self.train_dataloader
+        self.driver.set_deterministic_dataloader(self.dataloader)
+
+        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                                reproducible=self.callback_manager._need_reproducible_sampler)
+        # overfit-related setup;
+        if overfit_batches != 0:
+            self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
+        self.overfit_batches = overfit_batches
+
         self.evaluator = None
         self.monitor = monitor
         self.larger_better = larger_better
-        if metrics is not None and evaluate_dataloaders is not None:
-            check_evaluate_every(evaluate_every)
-            progress_bar = kwargs.get('progress_bar', 'auto')  # if this is not a str or None,
-            if not (isinstance(progress_bar, str) or progress_bar is None):  # it should be a ProgressCallback; take its name.
-                progress_bar = progress_bar.name
-            self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
-                                       driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
-                                       evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
-                                       output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
-                                       use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
-                                       progress_bar=progress_bar,
-                                       check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+        if metrics is not None:
+            if overfit_batches != 0:
+                evaluate_dataloaders = self.dataloader
+            if evaluate_dataloaders is not None:
+                check_evaluate_every(evaluate_every)
+                progress_bar = kwargs.get('progress_bar', 'auto')  # if this is not a str or None,
+                if not (isinstance(progress_bar, str) or progress_bar is None):  # it should be a ProgressCallback; take its name.
+                    progress_bar = progress_bar.name
+                self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
+                                           driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
+                                           evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
+                                           output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
+                                           use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
+                                           progress_bar=progress_bar,
+                                           check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
+            else:
+                raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

         if train_fn is not None and not isinstance(train_fn, str):
             raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
         self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
         self.train_fn = train_fn

-        self.dataloader = self.train_dataloader
-        self.driver.set_deterministic_dataloader(self.dataloader)
-
-        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
-                                                                reproducible=self.callback_manager._need_reproducible_sampler)
-
         self.evaluate_batch_step_fn = evaluate_batch_step_fn
         self.kwargs = kwargs



fastNLP/core/dataloaders/__init__.py  +5 -2

@@ -7,10 +7,13 @@ __all__ = [
     'prepare_paddle_dataloader',
     'prepare_torch_dataloader',

-    "prepare_dataloader"
+    "prepare_dataloader",
+
+    "OverfitDataLoader"
 ]

 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
 from .prepare_dataloader import prepare_dataloader
+from .utils import OverfitDataLoader

fastNLP/core/dataloaders/utils.py  +33 -2

@@ -1,4 +1,5 @@
-from typing import Callable, Any, Union
+import os
+from typing import Callable, Any, Union, Sequence
 from abc import ABC
 import inspect
 import ast
@@ -6,7 +7,8 @@ import ast
 from ..log import logger
 from ..utils.cache_results import get_func_calls, truncate_start_blanks
 __all__ = [
-    "indice_collate_wrapper"
+    "indice_collate_wrapper",
+    "OverfitDataLoader"
 ]


@@ -111,6 +113,35 @@ class HasLenGetitemType(ABC):
         return NotImplemented


+class OverfitDataLoader:
+    """
+    A simple iterator that mimics a real dataloader: it takes a number of batches from the given dataloader so that the Trainer can provide the overfit feature.
+    """
+
+    def __init__(self, dataloader, overfit_batches: int):
+        self.dataloader = dataloader  # the real dataloader has to be attached to this object to handle operations aimed at the real dataloader;
+        self.batches = []
+        self.overfit_batches = int(overfit_batches)
+
+        if self.overfit_batches > len(dataloader):
+            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
+
+        for idx, batch in enumerate(dataloader):
+            if idx < self.overfit_batches or self.overfit_batches <= -1:
+                self.batches.append(batch)
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        for batch in self.batches:
+            yield batch
+
+    def __getattr__(self, item):
+        return getattr(self.dataloader, item)
+
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass
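To make the behaviour of the new class concrete, here is a small standalone sketch (not part of the commit) that wraps a plain torch DataLoader. It relies only on what the diff above shows: the constructor caches up to ``overfit_batches`` batches, ``__len__``/``__iter__`` replay the cache, and unknown attributes are proxied to the wrapped dataloader.

import torch
from torch.utils.data import DataLoader, TensorDataset
from fastNLP.core.dataloaders import OverfitDataLoader

dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
dl = DataLoader(dataset, batch_size=10, shuffle=True)

overfit_dl = OverfitDataLoader(dl, overfit_batches=3)
print(len(overfit_dl))         # 3 -- only the first three batches were cached
print(overfit_dl.batch_size)   # 10 -- attribute not found locally, proxied to the wrapped DataLoader
for epoch in range(2):
    for batch in overfit_dl:   # each epoch replays exactly the same three cached batches
        pass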


fastNLP/core/drivers/torch_driver/ddp.py  +0 -3

@@ -140,9 +140,6 @@ if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import BatchSampler
-    from torch.utils.data import RandomSampler as TorchRandomSampler
-    from torch.utils.data import SequentialSampler as TorchSequentialSampler
-    from torch.utils.data import BatchSampler as TorchBatchSampler

 __all__ = [
     'TorchDDPDriver'


fastNLP/core/drivers/torch_driver/torch_driver.py  +2 -1

@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
+from fastNLP.core.dataloaders import OverfitDataLoader


 class TorchDriver(Driver):
@@ -92,7 +93,7 @@ class TorchDriver(Driver):
         self.grad_scaler.update()

     def check_dataloader_legality(self, dataloader):
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
             raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
         if len(dataloader) == 0:
             logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "


fastNLP/core/drivers/torch_driver/utils.py  +24 -22

@@ -181,18 +181,16 @@ def replace_sampler(dataloader: "DataLoader", sampler):
     instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

     # 'multiprocessing_context' is a user-defined function;
-    instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+    if getattr(dataloader, 'multiprocessing_context', None) is not None:
+        instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

     # get the default signature of the dataloader's '__init__' function;
     init_params = dict(inspect.signature(dataloader.__init__).parameters)

-    # The reason for handling this separately is that users customizing their own dataloader may, for convenience, only define a few parameters and pass the rest via **kwargs.
-    # If they added some extra new parameters when initializing their own dataloader instance (this step is necessary in the first place, because it is the only way we can add
-    # the sampler; and the user may indeed have passed new parameters via **kwargs), and if we assume the user wrote "super().__init__(**kwargs)", then we can only look the
-    # parameters up in DataLoader itself;
+    # guard against the user's DataLoader subclassing pytorch's DataLoader while still passing arguments to the parent class via **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
-    if has_variadic_kwargs:
+    # written this way because an argument with the same name in a user-customized Dataloader may have a different default value, so we cannot simply overwrite it with update;
+    if has_variadic_kwargs and isinstance(dataloader, DataLoader):
+        # guard against the user having written super().__init__(**kwargs)
         for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
             if key not in init_params and key != 'self':
                 init_params[key] = value
@@ -204,7 +202,8 @@ def replace_sampler(dataloader: "DataLoader", sampler):
     non_default_params.add("dataset")

     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
+    if isinstance(dataloader, DataLoader):
+        reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

     batch_sampler = getattr(dataloader, "batch_sampler")
     if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -218,35 +217,31 @@ def replace_sampler(dataloader: "DataLoader", sampler):
            and p.name not in reconstruct_args
     }

-    # this error targets the case where an __init__ argument was not attached to self under the same name;
+    # these arguments were not found among the instance attributes, so there is no way to re-initialize the dataloader
     if required_args:
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
+            f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
+            f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
+            f"`{dataloader_self_name}`'s attribute."
        )

     # this error targets the case where the passed-in dataloader is not a plain DataLoader but a customized one whose __init__ has no **kwargs;
     if not has_variadic_kwargs:
-
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = reconstruct_args.keys() - init_params.keys()
         if missing_kwargs:
             missing_kwargs = sorted(missing_kwargs)
             dataloader_self_name = dataloader.__class__.__name__
             raise Exception(
-                f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. "
-                f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
-                "manually add the `DistributedSampler` as: "
-                f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+                f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
             )
+        # if there is no **kwargs, make sure only the needed arguments are passed in
+        if not isinstance(dataloader, DataLoader):
+            reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

     return type(dataloader)(**reconstruct_args)

@@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
             params_keys.remove(k)
     params = {k: getattr(dataloader, k) for k in params_keys}
     params["batch_sampler"] = new_batch_sampler
+
+    if not isinstance(dataloader, DataLoader):
+        init_params = dict(inspect.signature(dataloader.__init__).parameters)
+        has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
+        if not has_variadic_kwargs:
+            params = {key:value for key,value in params.items() if key in init_params}
+
     return type(dataloader)(**params)
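The reworked checks above spell out what a custom dataloader must look like for fastNLP to re-instantiate it when injecting a sampler: every required ``__init__`` argument has to be recoverable from an instance attribute of the same name, or the signature has to accept ``**kwargs``. Below is a hedged sketch of a subclass that satisfies this; the class and its ``my_option`` argument are illustrative, not from the repository.

from torch.utils.data import DataLoader

class MyDataLoader(DataLoader):
    def __init__(self, dataset, my_option: int = 0, **kwargs):
        # store the extra argument under the same name so that vars(self) exposes it
        # when replace_sampler/replace_batch_sampler rebuilds the instance
        self.my_option = my_option
        super().__init__(dataset, **kwargs)   # **kwargs lets sampler/batch_sampler be injected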






fastNLP/core/metrics/metric.py  +9 -1

@@ -98,7 +98,7 @@ class Metric:
         return _wrap_get_metric

     def __setattr__(self, key, value):
-        if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
+        if getattr(self, '_cannot_change_element', False):
             if key in self.elements and isinstance(value, (float, int, bool)):
                 self.elements[key].fill_value(value)
                 return
@@ -109,6 +109,14 @@ class Metric:
                 raise RuntimeError("Please use register_element() function to add Element.")
         object.__setattr__(self, key, value)

+    # this is only triggered when __getattribute__ fails to find the attribute; it is kept only to suppress ide warnings
+    def __getattr__(self, name: str) -> Element:
+        if 'elements' in self.__dict__:
+            elements = self.__dict__['elements']
+            if name in elements:
+                return elements[name]
+        raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))
+
     def _wrap_update(self, update):
         @functools.wraps(update)
         def _wrap_update(*args, **kwargs):
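A hedged sketch of why the new ``__getattr__`` is useful: registered elements live in ``self.elements`` but are normally read through plain attribute access, as in the custom metric below. The ``register_element``/``get_scalar`` calls are assumptions based on other fastNLP 0.8-era metrics (e.g. Accuracy), not something this diff shows.

from fastNLP.core.metrics import Metric

class CountMetric(Metric):
    def __init__(self):
        super().__init__()
        # assumed API: registers an aggregatable element stored in self.elements
        self.register_element(name='total', value=0, aggregate_method='sum')

    def update(self, pred):
        self.total += len(pred)    # attribute access to the element; the new __getattr__ backstops the lookup

    def get_metric(self) -> dict:
        return {'total': self.total.get_scalar()}   # get_scalar() assumed from other fastNLP metrics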


tests/core/controllers/test_trainer_w_evaluator_torch.py  +45 -0

@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
     assert trainer.driver.non_blocking is False
     assert trainer.driver.wo_auto_param_call is True

+    if dist.is_initialized():
+        dist.destroy_process_group()
+

 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,47 @@ def test_trainer_specific_params_2(
     assert _ddp_kwargs.get("broadcast_buffers") is True
     assert _ddp_kwargs.get("find_unused_parameters") is True

+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+@pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_w_evaluator_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        driver,
+        device,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    """
+    Test whether these special parameters are passed through correctly;
+    """
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver=driver,
+        device=device,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        n_epochs=2,
+        output_from_new_proc="all",
+        evaluate_every=-1,
+
+        torch_kwargs={
+            "non_blocking": False,
+            "set_grad_to_none": True
+        }
+
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()

tests/core/controllers/test_trainer_wo_evaluator_torch.py  +30 -0

@@ -361,5 +361,35 @@ def test_torch_wo_auto_param_call(
         dist.destroy_process_group()


+# test overfit_batches;
+@pytest.mark.torch
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
+@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
+@magic_argv_env_context
+def test_trainer_overfit_torch(
+        model_and_optimizers: TrainerParameters,
+        driver,
+        device,
+        overfit_batches,
+        num_train_batch_per_epoch
+):
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver=driver,
+        device=device,
+        overfit_batches=overfit_batches,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+        output_from_new_proc="all",
+        n_epochs=2,
+    )
+
+    trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)
+
+    if dist.is_initialized():
+        dist.destroy_process_group()


