
1. Changed f_rich_progress to temporarily stop the live display when there is no bar; 2. Changed how the samplers obtain the dataset length, to adapt to Jittor
tags/v1.0.0alpha
yh committed 3 years ago · commit 6903253e59
6 changed files with 46 additions and 41 deletions
  1. fastNLP/core/callbacks/more_evaluate_callback.py (+2, -1)
  2. fastNLP/core/callbacks/topk_saver.py (+3, -2)
  3. fastNLP/core/samplers/reproducible_batch_sampler.py (+15, -15)
  4. fastNLP/core/samplers/reproducible_sampler.py (+13, -13)
  5. fastNLP/core/samplers/unrepeated_sampler.py (+7, -7)
  6. fastNLP/core/utils/rich_progress.py (+6, -3)

fastNLP/core/callbacks/more_evaluate_callback.py (+2, -1)

@@ -85,7 +85,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
         if watch_monitor is None and evaluate_every is None:
             raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
         if watch_monitor is not None and evaluate_every is not None:
-            raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
+            raise RuntimeError(f"`evaluate_every`({evaluate_every}) and `watch_monitor`({watch_monitor}) "
+                               f"cannot be set at the same time.")

         if topk_monitor is not None and topk == 0:
             raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")


fastNLP/core/callbacks/topk_saver.py (+3, -2)

@@ -36,7 +36,8 @@ class Saver:
                  model_save_fn:Callable=None, **kwargs):
         if folder is None:
             folder = Path.cwd().absolute()
-            logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
+            if save_object is not None:
+                logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
         folder = Path(folder)
         if not folder.exists():
             folder.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ class TopkSaver(ResultsMonitor, Saver):
         if topk is None:
             topk = 0
         ResultsMonitor.__init__(self, monitor, larger_better)
-        Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)
+        Saver.__init__(self, folder, save_object if topk!=0 else None, only_state_dict, model_save_fn, **kwargs)

         if monitor is not None and topk == 0:
             raise RuntimeError("`monitor` is set, but `topk` is 0.")


fastNLP/core/samplers/reproducible_batch_sampler.py (+15, -15)

@@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.num_consumed_samples = 0
         self.during_iter = True

-        indices = list(range(len(self.dataset)))
+        indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))

         if self.shuffle:
             if self.num_consumed_samples > 0:  # restore the original order first, then drop the extra (already consumed) samples
@@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         if len(indices)%self.batch_size!=0:
             batches.append(indices[_num_batches*self.batch_size:])

-        need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
+        need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
         if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
             if len(batches) > 0:
                 if len(batches[-1])<self.batch_size:
@@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     @property
     def batch_idx_in_epoch(self):
         if self.drop_last:
-            return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
+            return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
         else:
-            return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
+            return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
                    (self.num_left_samples + self.batch_size - 1) // self.batch_size

@property
@@ -313,8 +313,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))

     def __len__(self)->int:
         """
@@ -332,7 +332,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                                " consumed. ")
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
-                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
                   'batch_size': self.batch_size,
                   'num_replicas': self.num_replicas}

@@ -347,7 +347,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             f"we cannot use {self.__class__.__name__} to load it."

         length = states['length']
-        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
                                             "and current dataset."
         self.seed = states['seed']
         self.epoch = states['epoch']
@@ -464,8 +464,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))

     def __len__(self)->int:
         """
@@ -515,7 +515,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if len(sorted_indices)%self.batch_size!=0:
             batches.append(sorted_indices[_num_batches*self.batch_size:])

-        need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
+        need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
         if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
             if len(batches) > 0:
                 if len(batches[-1])<self.batch_size:
@@ -593,7 +593,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                                " consumed. ")
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
-                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
                   'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
                   'num_replicas': self.num_replicas
                   }
@@ -609,7 +609,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             f"we cannot use {self.__class__.__name__} to load it."

         length = states['length']
-        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
                                             "and current dataset."
         self.seed = states['seed']
         self.epoch = states['epoch']
@@ -630,7 +630,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
     @property
     def batch_idx_in_epoch(self):
         if self.drop_last:
-            return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
+            return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
         else:
-            return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
+            return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
                    (self.num_left_samples + self.batch_size - 1) // self.batch_size
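Every `len(self.dataset)` in the samplers above is replaced with `getattr(self.dataset, 'total_len', len(self.dataset))`. The point is Jittor compatibility: a `jittor.dataset.Dataset` keeps its raw sample count in a `total_len` attribute (its `len()` can reflect the batched length instead), while PyTorch-style datasets only expose `__len__`. A minimal sketch of how the fallback resolves, with `ToyJittorDataset` and `ToyMapDataset` as hypothetical stand-ins; note that the `len(dataset)` default is evaluated eagerly, so every dataset must still support `len()`:

# Sketch of the length fallback used throughout this commit.

class ToyJittorDataset:
    """Stand-in for a Jittor-style dataset: `total_len` holds the sample count."""
    def __init__(self, n, batch_size=16):
        self.total_len = n                  # raw number of samples
        self.batch_size = batch_size
    def __len__(self):
        # Jittor-style: len() reflects the number of batches, not samples.
        return (self.total_len + self.batch_size - 1) // self.batch_size

class ToyMapDataset:
    """Stand-in for a PyTorch-style map dataset: len() is the sample count."""
    def __init__(self, n):
        self._n = n
    def __len__(self):
        return self._n

def num_samples(dataset):
    # Prefer `total_len` when present (Jittor); otherwise fall back to len().
    return getattr(dataset, 'total_len', len(dataset))

assert num_samples(ToyJittorDataset(100)) == 100   # len() alone would give 7
assert num_samples(ToyMapDataset(100)) == 100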

fastNLP/core/samplers/reproducible_sampler.py (+13, -13)

@@ -131,14 +131,14 @@ class RandomSampler(ReproducibleSampler):
         :return:
         """
         if self.shuffle:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
             seed = self.seed + self.epoch
             rng = np.random.default_rng(abs(seed))
             rng.shuffle(indices)
             if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order each epoch.
                 self.epoch -= 1
         else:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
         return indices

     def state_dict(self) -> Dict:
@@ -155,8 +155,8 @@ class RandomSampler(ReproducibleSampler):
             f"we cannot use {self.__class__.__name__} to load it."

         length = states['length']
-        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
-                                            f"and current dataset({len(self.dataset)})."
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
+                                            f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
         self.seed = states['seed']
         self.epoch = states['epoch']
         self.num_consumed_samples = states['num_consumed_samples']
@@ -208,8 +208,8 @@ class RandomSampler(ReproducibleSampler):
         :return:
         """
         num_consumed_samples = self.num_consumed_samples
-        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
-            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
+        return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))


 class SequentialSampler(RandomSampler):
@@ -258,11 +258,11 @@ class SequentialSampler(RandomSampler):

         :return:
         """
-        return list(range(len(self.dataset)))
+        return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))

     def state_dict(self) -> Dict:
         states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
-                  'length': len(self.dataset)
+                  'length': getattr(self.dataset, 'total_len', len(self.dataset))
                   }
         return states

@@ -275,8 +275,8 @@ class SequentialSampler(RandomSampler):
             f"we cannot use {self.__class__.__name__} to load it."

         length = states['length']
-        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
-                                            f"and current dataset({len(self.dataset)})."
+        assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
+                                            f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
         self.num_consumed_samples = states['num_consumed_samples']
         if self.num_consumed_samples >= length:  # if the last sample had already been reached when saving, just reset to 0
             self.num_consumed_samples = 0
@@ -314,9 +314,9 @@ class SortedSampler(SequentialSampler):
         except BaseException as e:
             logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")

-        assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
-                                            f"`length`({len(length)}) should be equal."
-        assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length."
+        assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \
+                                            f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal."
+        assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length."

         self.length = np.array(length, dtype=int)  # the index order from longest to shortest.
         self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # sorted by length, from longest to shortest
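For reference, the `np.argsort(self.length)[::-1]` in the last hunk produces sample indices ordered longest-first, which is the order `SortedSampler` iterates in. A tiny worked example:

import numpy as np

length = np.array([3, 10, 7], dtype=int)            # e.g. token counts per sample
sorted_indices = np.argsort(length)[::-1].tolist()  # descending by length
assert sorted_indices == [1, 2, 0]                  # sample 1 (length 10) first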


fastNLP/core/samplers/unrepeated_sampler.py (+7, -7)

@@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         Returns the number of indices produced by one full iteration of the sampler. In the multi-GPU case, only the current rank is considered;
         :return:
         """
-        num_common = len(self.dataset)//self.num_replicas
-        num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
+        num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas
+        num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas))
         return num_samples

     def __iter__(self):
@@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         :return:
         """
         if self.shuffle:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
             seed = self.seed + self.epoch
             rng = np.random.default_rng(abs(seed))
             rng.shuffle(indices)
             if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order each epoch.
                 self.epoch -= 1
         else:
-            indices = list(range(len(self.dataset)))
+            indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
         return indices

     def set_epoch(self, epoch: int) -> None:
@@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
         :param rank:
         :return:
         """
-        assert num_replicas<=len(self.dataset), f"The number of replicas({num_replicas}) should be lesser than the " \
-                                                f"number of samples({len(self.dataset)})."
+        assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \
+                                                f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})."
         assert num_replicas>0 and isinstance(num_replicas, int)
         assert isinstance(rank, int) and 0<=rank<num_replicas
         # note that when this is initialized, all state should default to the state at the very start of an epoch;
@@ -147,5 +147,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
             yield index

     def generate_indices(self) -> List[int]:
-        return list(range(len(self.dataset)))
+        return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
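The `num_common + int(self.rank < remainder)` arithmetic in `UnrepeatedRandomSampler.num_samples` splits the dataset evenly and hands each leftover sample to one of the lowest ranks, so across all ranks every sample appears exactly once and no padding is needed. A small sketch of that split (the helper name is illustrative):

def per_rank_samples(total_len, num_replicas, rank):
    # Even share, plus one extra sample for each of the first
    # (total_len % num_replicas) ranks.
    num_common = total_len // num_replicas
    return num_common + int(rank < (total_len - num_common * num_replicas))

# 10 samples over 3 replicas -> 4 + 3 + 3; every sample used exactly once.
assert [per_rank_samples(10, 3, r) for r in range(3)] == [4, 3, 3]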


fastNLP/core/utils/rich_progress.py (+6, -3)

@@ -149,9 +149,12 @@ class FRichProgress(Progress, metaclass=Singleton):
         super().stop_task(task_id)
         super().remove_task(task_id)
         self.refresh()  # so that no bar is left behind
-        # This was commented out because automatic line breaks showed up while a dataset was being applied multiple times; it was presumably kept before because the bar would not disappear after evaluate finished.
-        # if len(self._tasks) == 0:
-        #     self.live.stop()
+        if len(self._tasks) == 0:
+            # hack the line function here to keep stop() from printing an empty line
+            old_line = getattr(self.live.console, 'line')
+            setattr(self.live.console, 'line', lambda *args, **kwargs: ...)
+            self.live.stop()
+            setattr(self.live.console, 'line', old_line)

     def start(self) -> None:
         super().start()
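Stopping a rich `Live` in a regular terminal normally emits a trailing newline through `console.line()`, which is evidently the blank line this hack suppresses: `line` is swapped for a no-op just for the duration of `stop()`, then restored. A standalone sketch of the same trick against rich's public API (assuming a non-Jupyter terminal, with rich installed):

from rich.console import Console
from rich.live import Live

console = Console()
live = Live("working...", console=console)
live.start()
# ... update the live display here ...

# Temporarily no-op console.line so Live.stop() cannot emit a blank line,
# then restore the real method.
old_line = console.line
console.line = lambda *args, **kwargs: None
try:
    live.stop()
finally:
    console.line = old_line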

