f_rich_progress在没有bar的时候暂时关闭live;2.修改sampler获取dataset长度的方式以适配jittortags/v1.0.0alpha
@@ -85,7 +85,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
if watch_monitor is None and evaluate_every is None: | if watch_monitor is None and evaluate_every is None: | ||||
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") | raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") | ||||
if watch_monitor is not None and evaluate_every is not None: | if watch_monitor is not None and evaluate_every is not None: | ||||
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") | |||||
raise RuntimeError(f"`evaluate_every`({evaluate_every}) and `watch_monitor`({watch_monitor}) " | |||||
f"cannot be set at the same time.") | |||||
if topk_monitor is not None and topk == 0: | if topk_monitor is not None and topk == 0: | ||||
raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") | raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") | ||||
@@ -36,7 +36,8 @@ class Saver: | |||||
model_save_fn:Callable=None, **kwargs): | model_save_fn:Callable=None, **kwargs): | ||||
if folder is None: | if folder is None: | ||||
folder = Path.cwd().absolute() | folder = Path.cwd().absolute() | ||||
logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.") | |||||
if save_object is not None: | |||||
logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.") | |||||
folder = Path(folder) | folder = Path(folder) | ||||
if not folder.exists(): | if not folder.exists(): | ||||
folder.mkdir(parents=True, exist_ok=True) | folder.mkdir(parents=True, exist_ok=True) | ||||
@@ -208,7 +209,7 @@ class TopkSaver(ResultsMonitor, Saver): | |||||
if topk is None: | if topk is None: | ||||
topk = 0 | topk = 0 | ||||
ResultsMonitor.__init__(self, monitor, larger_better) | ResultsMonitor.__init__(self, monitor, larger_better) | ||||
Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs) | |||||
Saver.__init__(self, folder, save_object if topk!=0 else None, only_state_dict, model_save_fn, **kwargs) | |||||
if monitor is not None and topk == 0: | if monitor is not None and topk == 0: | ||||
raise RuntimeError("`monitor` is set, but `topk` is 0.") | raise RuntimeError("`monitor` is set, but `topk` is 0.") | ||||
@@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
self.during_iter = True | self.during_iter = True | ||||
indices = list(range(len(self.dataset))) | |||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
if self.shuffle: | if self.shuffle: | ||||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | ||||
@@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
if len(indices)%self.batch_size!=0: | if len(indices)%self.batch_size!=0: | ||||
batches.append(indices[_num_batches*self.batch_size:]) | batches.append(indices[_num_batches*self.batch_size:]) | ||||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas | |||||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | ||||
if len(batches) > 0: | if len(batches) > 0: | ||||
if len(batches[-1])<self.batch_size: | if len(batches[-1])<self.batch_size: | ||||
@@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | else: | ||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | (self.num_left_samples + self.batch_size - 1) // self.batch_size | ||||
@property | @property | ||||
@@ -313,8 +313,8 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
def __len__(self)->int: | def __len__(self)->int: | ||||
""" | """ | ||||
@@ -332,7 +332,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
" consumed. ") | " consumed. ") | ||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | 'batch_size': self.batch_size, | ||||
'num_replicas': self.num_replicas} | 'num_replicas': self.num_replicas} | ||||
@@ -347,7 +347,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | "and current dataset." | ||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
@@ -464,8 +464,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
def __len__(self)->int: | def __len__(self)->int: | ||||
""" | """ | ||||
@@ -515,7 +515,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
if len(sorted_indices)%self.batch_size!=0: | if len(sorted_indices)%self.batch_size!=0: | ||||
batches.append(sorted_indices[_num_batches*self.batch_size:]) | batches.append(sorted_indices[_num_batches*self.batch_size:]) | ||||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas | |||||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | ||||
if len(batches) > 0: | if len(batches) > 0: | ||||
if len(batches[-1])<self.batch_size: | if len(batches[-1])<self.batch_size: | ||||
@@ -593,7 +593,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | ||||
" consumed. ") | " consumed. ") | ||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | ||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | ||||
'num_replicas': self.num_replicas | 'num_replicas': self.num_replicas | ||||
} | } | ||||
@@ -609,7 +609,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | "and current dataset." | ||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
@@ -630,7 +630,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | else: | ||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | (self.num_left_samples + self.batch_size - 1) // self.batch_size |
@@ -131,14 +131,14 @@ class RandomSampler(ReproducibleSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.shuffle: | if self.shuffle: | ||||
indices = list(range(len(self.dataset))) | |||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
seed = self.seed + self.epoch | seed = self.seed + self.epoch | ||||
rng = np.random.default_rng(abs(seed)) | rng = np.random.default_rng(abs(seed)) | ||||
rng.shuffle(indices) | rng.shuffle(indices) | ||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | ||||
self.epoch -= 1 | self.epoch -= 1 | ||||
else: | else: | ||||
indices = list(range(len(self.dataset))) | |||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
return indices | return indices | ||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
@@ -155,8 +155,8 @@ class RandomSampler(ReproducibleSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({len(self.dataset)})." | |||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
@@ -208,8 +208,8 @@ class RandomSampler(ReproducibleSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
num_consumed_samples = self.num_consumed_samples | num_consumed_samples = self.num_consumed_samples | ||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||||
class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
@@ -258,11 +258,11 @@ class SequentialSampler(RandomSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
return list(range(len(self.dataset))) | |||||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | ||||
'length': len(self.dataset) | |||||
'length': getattr(self.dataset, 'total_len', len(self.dataset)) | |||||
} | } | ||||
return states | return states | ||||
@@ -275,8 +275,8 @@ class SequentialSampler(RandomSampler): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({len(self.dataset)})." | |||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
@@ -314,9 +314,9 @@ class SortedSampler(SequentialSampler): | |||||
except BaseException as e: | except BaseException as e: | ||||
logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | ||||
assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \ | |||||
f"`length`({len(length)}) should be equal." | |||||
assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length." | |||||
assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ | |||||
f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." | |||||
assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." | |||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | ||||
@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | ||||
:return: | :return: | ||||
""" | """ | ||||
num_common = len(self.dataset)//self.num_replicas | |||||
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||||
num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas | |||||
num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) | |||||
return num_samples | return num_samples | ||||
def __iter__(self): | def __iter__(self): | ||||
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.shuffle: | if self.shuffle: | ||||
indices = list(range(len(self.dataset))) | |||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
seed = self.seed + self.epoch | seed = self.seed + self.epoch | ||||
rng = np.random.default_rng(abs(seed)) | rng = np.random.default_rng(abs(seed)) | ||||
rng.shuffle(indices) | rng.shuffle(indices) | ||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | ||||
self.epoch -= 1 | self.epoch -= 1 | ||||
else: | else: | ||||
indices = list(range(len(self.dataset))) | |||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
return indices | return indices | ||||
def set_epoch(self, epoch: int) -> None: | def set_epoch(self, epoch: int) -> None: | ||||
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
:param rank: | :param rank: | ||||
:return: | :return: | ||||
""" | """ | ||||
assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
f"number of samples({len(self.dataset)})." | |||||
assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||||
f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||||
assert num_replicas>0 and isinstance(num_replicas, int) | assert num_replicas>0 and isinstance(num_replicas, int) | ||||
assert isinstance(rank, int) and 0<=rank<num_replicas | assert isinstance(rank, int) and 0<=rank<num_replicas | ||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | ||||
@@ -147,5 +147,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||||
yield index | yield index | ||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
return list(range(len(self.dataset))) | |||||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||||
@@ -149,9 +149,12 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
super().stop_task(task_id) | super().stop_task(task_id) | ||||
super().remove_task(task_id) | super().remove_task(task_id) | ||||
self.refresh() # 使得bar不残留 | self.refresh() # 使得bar不残留 | ||||
# 这里需要注释掉的原因是由于,在dataset多次apply的过程中会出现自动换行的问题。以前保留这个的原因应该是由于evaluate结束bar不消失。 | |||||
# if len(self._tasks) == 0: | |||||
# self.live.stop() | |||||
if len(self._tasks) == 0: | |||||
# 这里将这个line函数给hack一下防止stop的时候打印出空行 | |||||
old_line = getattr(self.live.console, 'line') | |||||
setattr(self.live.console, 'line', lambda *args,**kwargs:...) | |||||
self.live.stop() | |||||
setattr(self.live.console, 'line', old_line) | |||||
def start(self) -> None: | def start(self) -> None: | ||||
super().start() | super().start() | ||||