|
|
@@ -78,15 +78,17 @@ class RichCallback(ProgressCallback): |
|
|
|
super(RichCallback, self).on_after_trainer_initialized(trainer, driver) |
|
|
|
|
|
|
|
def on_train_begin(self, trainer): |
|
|
|
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, |
|
|
|
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)) |
|
|
|
self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}', |
|
|
|
total=trainer.n_epochs, |
|
|
|
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)* |
|
|
|
trainer.n_epochs) |
|
|
|
|
|
|
|
def on_train_epoch_begin(self, trainer): |
|
|
|
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) |
|
|
|
if 'batch' in self.task2id: |
|
|
|
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) |
|
|
|
else: |
|
|
|
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', |
|
|
|
self.task2id['batch'] = self.progress_bar.add_task(description=f'Batch:{trainer.batch_idx_in_epoch}', |
|
|
|
total=trainer.num_batches_per_epoch, |
|
|
|
completed=trainer.batch_idx_in_epoch) |
|
|
|
|
|
|
@@ -249,9 +251,10 @@ class TqdmCallback(ProgressCallback): |
|
|
|
self.num_signs = 10 |
|
|
|
|
|
|
|
def on_train_begin(self, trainer): |
|
|
|
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, |
|
|
|
self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}', |
|
|
|
total=trainer.n_epochs, |
|
|
|
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', |
|
|
|
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)) |
|
|
|
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)*trainer.n_epochs) |
|
|
|
|
|
|
|
def on_train_epoch_begin(self, trainer): |
|
|
|
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) |
|
|
@@ -279,7 +282,7 @@ class TqdmCallback(ProgressCallback): |
|
|
|
self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True) |
|
|
|
|
|
|
|
def on_evaluate_end(self, trainer, results): |
|
|
|
if len(results)==0: |
|
|
|
if len(results) == 0: |
|
|
|
return |
|
|
|
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' |
|
|
|
text = '' |
|
|
|