From c0b67a2bc9fdd932ba8e812ba579681c1fbac9c4 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Fri, 11 Jan 2019 20:24:35 +0800 Subject: [PATCH] fix tests --- fastNLP/core/callback.py | 3 --- fastNLP/core/trainer.py | 2 +- test/core/test_dataset.py | 2 +- test/core/test_metrics.py | 3 ++- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index de6303ad..ce9627ea 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -190,9 +190,6 @@ class EchoCallback(Callback): def before_batch(self, batch_x, batch_y, indices): print("before_batch") - print("batch_x:", batch_x) - print("batch_y:", batch_y) - print("indices: ", indices) def before_loss(self, batch_y, predict_y): print("before_loss") diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 4d27540f..109315a3 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -257,7 +257,7 @@ class Trainer(object): self._update() # lr scheduler; lr_finder; one_cycle - self.callback_manager.after_step() + self.callback_manager.after_step(self.optimizer) self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) for name, param in self.model.named_parameters(): diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 01963af6..261d42b3 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -197,4 +197,4 @@ class TestDataSetIter(unittest.TestCase): def test__repr__(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) for iter in ds: - self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}") + self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}") diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 1dbab314..80ed54e2 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -360,7 +360,8 @@ class TestBMESF1PreRecMetric(unittest.TestCase): metric = BMESF1PreRecMetric() metric(pred_dict, target_dict) - self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0}) + self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) + class TestUsefulFunctions(unittest.TestCase): # 测试metrics.py中一些看上去挺有用的函数