diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index 0a351354..b01cccc4 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -132,6 +132,10 @@ class _TruncatedDataLoader: def __getattr__(self, item): return getattr(self.dataloader, item) + def __setattr__(self, key, value): + # 添加该函数使得在进行实验性训练或者评测时,用户对于 trainer.dataloader 的感觉和正常训练完全一样; + setattr(self.dataloader, key, value) + def check_evaluate_every(evaluate_every): r"""