diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py
index 538ffc10..e11bacde 100644
--- a/fastNLP/core/callbacks/more_evaluate_callback.py
+++ b/fastNLP/core/callbacks/more_evaluate_callback.py
@@ -85,7 +85,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
         if watch_monitor is None and evaluate_every is None:
             raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
         if watch_monitor is not None and evaluate_every is not None:
-            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
+            raise RuntimeError(f"`evaluate_every`({evaluate_every}) and `watch_monitor`({watch_monitor}) "
+                               f"cannot be set at the same time.")
 
         if topk_monitor is not None and topk == 0:
             raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py
index 32341e7b..0317b7e2 100644
--- a/fastNLP/core/callbacks/topk_saver.py
+++ b/fastNLP/core/callbacks/topk_saver.py
@@ -36,7 +36,8 @@ class Saver:
                  model_save_fn:Callable=None, **kwargs):
         if folder is None:
             folder = Path.cwd().absolute()
-            logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
+            if save_object is not None:
+                logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
         folder = Path(folder)
         if not folder.exists():
             folder.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ class TopkSaver(ResultsMonitor, Saver):
         if topk is None:
             topk = 0
         ResultsMonitor.__init__(self, monitor, larger_better)
-        Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)
+        Saver.__init__(self, folder, save_object if topk!=0 else None, only_state_dict, model_save_fn, **kwargs)
         if monitor is not None and topk == 0:
             raise RuntimeError("`monitor` is set, but `topk` is 0.")
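Both topk_saver.py hunks serve a single goal: when `topk == 0` nothing will ever be saved, so `TopkSaver` now hands `save_object=None` down to `Saver`, and `Saver` only logs the fallback folder when it actually has something to save. A minimal sketch of the resulting guard (the classes are stripped down here; `Saver`'s real signature has more parameters):

import logging
from pathlib import Path

logger = logging.getLogger("sketch")

class Saver:
    def __init__(self, folder=None, save_object=None):
        if folder is None:
            folder = Path.cwd().absolute()
            if save_object is not None:  # stay silent when there is nothing to save
                logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} "
                            f"to save and load your model.")
        self.folder = Path(folder)

class TopkSaver(Saver):
    def __init__(self, topk=0, folder=None, save_object='model'):
        # topk == 0 means "never save anything", so drop save_object before Saver sees it
        super().__init__(folder, save_object if topk != 0 else None)

TopkSaver()        # topk == 0 -> no fallback-folder log line
TopkSaver(topk=3)  # topk > 0  -> the log line appears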
") states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, - 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, + 'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, 'batch_size': self.batch_size, 'num_replicas': self.num_replicas} @@ -347,7 +347,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ + assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ "and current dataset." self.seed = states['seed'] self.epoch = states['epoch'] @@ -464,8 +464,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): :return: """ num_consumed_samples = self.num_consumed_samples - return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ - self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) + return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ + self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) def __len__(self)->int: """ @@ -515,7 +515,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if len(sorted_indices)%self.batch_size!=0: batches.append(sorted_indices[_num_batches*self.batch_size:]) - need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas + need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: if len(batches) > 0: if len(batches[-1]) Dict: @@ -155,8 +155,8 @@ class RandomSampler(ReproducibleSampler): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ - f"and current dataset({len(self.dataset)})." + assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." 
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
     def state_dict(self) -> Dict:
@@ -155,8 +155,8 @@ class RandomSampler(ReproducibleSampler):
             f"we cannot use {self.__class__.__name__} to load it."
 
         length = states['length']
-        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
-            f"and current dataset({len(self.dataset)})."
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
+            f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
         self.seed = states['seed']
         self.epoch = states['epoch']
         self.num_consumed_samples = states['num_consumed_samples']
@@ -208,8 +208,8 @@ class RandomSampler(ReproducibleSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
 
 
 class SequentialSampler(RandomSampler):
@@ -258,11 +258,11 @@
 
         :return:
         """
-        return list(range(len(self.dataset)))
+        return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
 
     def state_dict(self) -> Dict:
         states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
-                  'length': len(self.dataset)
+                  'length': getattr(self.dataset, 'total_len', len(self.dataset))
                  }
         return states
@@ -275,8 +275,8 @@ class SequentialSampler(RandomSampler):
             f"we cannot use {self.__class__.__name__} to load it."
 
         length = states['length']
-        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
-            f"and current dataset({len(self.dataset)})."
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
+            f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
         self.num_consumed_samples = states['num_consumed_samples']
         if self.num_consumed_samples >= length:  # the checkpoint was taken after the last sample was consumed, so reset directly to 0
             self.num_consumed_samples = 0
@@ -314,9 +314,9 @@ class SortedSampler(SequentialSampler):
         except BaseException as e:
             logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")
 
-        assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
-            f"`length`({len(length)}) should be equal."
-        assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length."
+        assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \
+            f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal."
+        assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length."
         self.length = np.array(length, dtype=int)  # indices ordered from longest to shortest
         self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # sorted by length, from high to low
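The `state_dict`/`load_state_dict` hunks record the effective length at save time and assert it again at load time, so a resumed run fails fast if the dataset changed size, and a checkpoint taken exactly at epoch end resumes from sample 0. A toy round-trip using the same fallback (`save_state`/`load_state` are illustrative names, not fastNLP API):

def save_state(dataset, num_consumed):
    # mirrors state_dict(): persist progress plus the effective dataset length
    return {'num_consumed_samples': num_consumed,
            'length': getattr(dataset, 'total_len', len(dataset))}

def load_state(dataset, states):
    # mirrors load_state_dict(): refuse to resume on a resized dataset
    length = states['length']
    assert length == getattr(dataset, 'total_len', len(dataset)), \
        "The number of samples is different between the checkpoint record and current dataset."
    num_consumed = states['num_consumed_samples']
    if num_consumed >= length:  # checkpoint taken at epoch end: start the next epoch from 0
        num_consumed = 0
    return num_consumed

data = list(range(8))
assert load_state(data, save_state(data, 8)) == 0  # fully consumed -> reset
assert load_state(data, save_state(data, 3)) == 3  # resume mid-epoch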
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
index b76f9ffd..0ff55674 100644
--- a/fastNLP/core/samplers/unrepeated_sampler.py
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         Return how many indices one full iteration of the sampler will produce; in the multi-card case, only the current rank is considered.
 
         :return:
         """
-        num_common = len(self.dataset)//self.num_replicas
-        num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
+        num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas
+        num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas))
         return num_samples
 
     def __iter__(self):
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         :return:
         """
         if self.shuffle:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
             seed = self.seed + self.epoch
             rng = np.random.default_rng(abs(seed))
             rng.shuffle(indices)
             if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order each epoch
                 self.epoch -= 1
         else:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
         return indices
 
     def set_epoch(self, epoch: int) -> None:
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         :param rank:
         :return:
         """
-        assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \
-            f"number of samples({len(self.dataset)})."
+        assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \
+            f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})."
         assert num_replicas>0 and isinstance(num_replicas, int)
         assert isinstance(rank, int) and 0<=rank<num_replicas
 
     def generate_indices(self) -> List[int]:
-        return list(range(len(self.dataset)))
+        return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py
index 0e6d5a01..d8e9d45b 100644
--- a/fastNLP/core/utils/rich_progress.py
+++ b/fastNLP/core/utils/rich_progress.py
@@ -149,9 +149,12 @@ class FRichProgress(Progress, metaclass=Singleton):
         super().stop_task(task_id)
         super().remove_task(task_id)
         self.refresh()  # so the bar leaves no residue behind
-        # Commented out because repeated dataset.apply calls triggered spurious line wraps; it had previously been kept because otherwise the bar would not disappear after evaluate finished.
-        # if len(self._tasks) == 0:
-        #     self.live.stop()
+        if len(self._tasks) == 0:
+            # hack the console's `line` function here so that stop() does not print a blank line
+            old_line = getattr(self.live.console, 'line')
+            setattr(self.live.console, 'line', lambda *args,**kwargs:...)
+            self.live.stop()
+            setattr(self.live.console, 'line', old_line)
 
     def start(self) -> None:
         super().start()
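The rich_progress change revives the previously commented-out `live.stop()`, but temporarily no-ops the console's `line` method around the call so that stopping the live display cannot emit a trailing blank line. The same swap-and-restore pattern in isolation, as a sketch against rich's public `Console`/`Live` API (whether `stop()` actually prints the blank line depends on the rich version):

from rich.console import Console
from rich.live import Live

live = Live(console=Console())
live.start()

# Replace Console.line with a no-op while stopping, then restore it,
# mirroring the hack in FRichProgress.stop_task.
old_line = getattr(live.console, 'line')
setattr(live.console, 'line', lambda *args, **kwargs: ...)
try:
    live.stop()
finally:
    setattr(live.console, 'line', old_line)

Restoring inside `finally` (which the in-tree version skips) keeps the console usable even if `stop()` raises.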