f_rich_progress在没有bar的时候暂时关闭live;2.修改sampler获取dataset长度的方式以适配jittortags/v1.0.0alpha
@@ -85,7 +85,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||
if watch_monitor is None and evaluate_every is None: | |||
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") | |||
if watch_monitor is not None and evaluate_every is not None: | |||
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") | |||
raise RuntimeError(f"`evaluate_every`({evaluate_every}) and `watch_monitor`({watch_monitor}) " | |||
f"cannot be set at the same time.") | |||
if topk_monitor is not None and topk == 0: | |||
raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") | |||
@@ -36,7 +36,8 @@ class Saver: | |||
model_save_fn:Callable=None, **kwargs): | |||
if folder is None: | |||
folder = Path.cwd().absolute() | |||
logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.") | |||
if save_object is not None: | |||
logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.") | |||
folder = Path(folder) | |||
if not folder.exists(): | |||
folder.mkdir(parents=True, exist_ok=True) | |||
@@ -208,7 +209,7 @@ class TopkSaver(ResultsMonitor, Saver): | |||
if topk is None: | |||
topk = 0 | |||
ResultsMonitor.__init__(self, monitor, larger_better) | |||
Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs) | |||
Saver.__init__(self, folder, save_object if topk!=0 else None, only_state_dict, model_save_fn, **kwargs) | |||
if monitor is not None and topk == 0: | |||
raise RuntimeError("`monitor` is set, but `topk` is 0.") | |||
@@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
self.num_consumed_samples = 0 | |||
self.during_iter = True | |||
indices = list(range(len(self.dataset))) | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
if self.shuffle: | |||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | |||
@@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
if len(indices)%self.batch_size!=0: | |||
batches.append(indices[_num_batches*self.batch_size:]) | |||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas | |||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||
if len(batches) > 0: | |||
if len(batches[-1])<self.batch_size: | |||
@@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
@property | |||
def batch_idx_in_epoch(self): | |||
if self.drop_last: | |||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
else: | |||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | |||
@property | |||
@@ -313,8 +313,8 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
:return: | |||
""" | |||
num_consumed_samples = self.num_consumed_samples | |||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
def __len__(self)->int: | |||
""" | |||
@@ -332,7 +332,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
" consumed. ") | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, | |||
'num_replicas': self.num_replicas} | |||
@@ -347,7 +347,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||
"and current dataset." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
@@ -464,8 +464,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
:return: | |||
""" | |||
num_consumed_samples = self.num_consumed_samples | |||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
def __len__(self)->int: | |||
""" | |||
@@ -515,7 +515,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
if len(sorted_indices)%self.batch_size!=0: | |||
batches.append(sorted_indices[_num_batches*self.batch_size:]) | |||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas | |||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas | |||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||
if len(batches) > 0: | |||
if len(batches[-1])<self.batch_size: | |||
@@ -593,7 +593,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||
" consumed. ") | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle, | |||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
'num_replicas': self.num_replicas | |||
} | |||
@@ -609,7 +609,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \ | |||
"and current dataset." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
@@ -630,7 +630,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
@property | |||
def batch_idx_in_epoch(self): | |||
if self.drop_last: | |||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||
else: | |||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||
(self.num_left_samples + self.batch_size - 1) // self.batch_size |
@@ -131,14 +131,14 @@ class RandomSampler(ReproducibleSampler): | |||
:return: | |||
""" | |||
if self.shuffle: | |||
indices = list(range(len(self.dataset))) | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
seed = self.seed + self.epoch | |||
rng = np.random.default_rng(abs(seed)) | |||
rng.shuffle(indices) | |||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
self.epoch -= 1 | |||
else: | |||
indices = list(range(len(self.dataset))) | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
return indices | |||
def state_dict(self) -> Dict: | |||
@@ -155,8 +155,8 @@ class RandomSampler(ReproducibleSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({len(self.dataset)})." | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
@@ -208,8 +208,8 @@ class RandomSampler(ReproducibleSampler): | |||
:return: | |||
""" | |||
num_consumed_samples = self.num_consumed_samples | |||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas)) | |||
class SequentialSampler(RandomSampler): | |||
@@ -258,11 +258,11 @@ class SequentialSampler(RandomSampler): | |||
:return: | |||
""" | |||
return list(range(len(self.dataset))) | |||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
def state_dict(self) -> Dict: | |||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset) | |||
'length': getattr(self.dataset, 'total_len', len(self.dataset)) | |||
} | |||
return states | |||
@@ -275,8 +275,8 @@ class SequentialSampler(RandomSampler): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({len(self.dataset)})." | |||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
self.num_consumed_samples = 0 | |||
@@ -314,9 +314,9 @@ class SortedSampler(SequentialSampler): | |||
except BaseException as e: | |||
logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.") | |||
assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \ | |||
f"`length`({len(length)}) should be equal." | |||
assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length." | |||
assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \ | |||
f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal." | |||
assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length." | |||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||
:return: | |||
""" | |||
num_common = len(self.dataset)//self.num_replicas | |||
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas | |||
num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas)) | |||
return num_samples | |||
def __iter__(self): | |||
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
:return: | |||
""" | |||
if self.shuffle: | |||
indices = list(range(len(self.dataset))) | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
seed = self.seed + self.epoch | |||
rng = np.random.default_rng(abs(seed)) | |||
rng.shuffle(indices) | |||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
self.epoch -= 1 | |||
else: | |||
indices = list(range(len(self.dataset))) | |||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
return indices | |||
def set_epoch(self, epoch: int) -> None: | |||
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
:param rank: | |||
:return: | |||
""" | |||
assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
f"number of samples({len(self.dataset)})." | |||
assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \ | |||
f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})." | |||
assert num_replicas>0 and isinstance(num_replicas, int) | |||
assert isinstance(rank, int) and 0<=rank<num_replicas | |||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||
@@ -147,5 +147,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||
yield index | |||
def generate_indices(self) -> List[int]: | |||
return list(range(len(self.dataset))) | |||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset)))) | |||
@@ -149,9 +149,12 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
super().stop_task(task_id) | |||
super().remove_task(task_id) | |||
self.refresh() # 使得bar不残留 | |||
# 这里需要注释掉的原因是由于,在dataset多次apply的过程中会出现自动换行的问题。以前保留这个的原因应该是由于evaluate结束bar不消失。 | |||
# if len(self._tasks) == 0: | |||
# self.live.stop() | |||
if len(self._tasks) == 0: | |||
# 这里将这个line函数给hack一下防止stop的时候打印出空行 | |||
old_line = getattr(self.live.console, 'line') | |||
setattr(self.live.console, 'line', lambda *args,**kwargs:...) | |||
self.live.stop() | |||
setattr(self.live.console, 'line', old_line) | |||
def start(self) -> None: | |||
super().start() | |||