Browse Source

fix tests

tags/v0.4.10
FengZiYjun 6 years ago
parent
commit
c0b67a2bc9
4 changed files with 4 additions and 6 deletions
  1. +0
    -3
      fastNLP/core/callback.py
  2. +1
    -1
      fastNLP/core/trainer.py
  3. +1
    -1
      test/core/test_dataset.py
  4. +2
    -1
      test/core/test_metrics.py

+ 0
- 3
fastNLP/core/callback.py View File

@@ -190,9 +190,6 @@ class EchoCallback(Callback):

def before_batch(self, batch_x, batch_y, indices):
print("before_batch")
print("batch_x:", batch_x)
print("batch_y:", batch_y)
print("indices: ", indices)

def before_loss(self, batch_y, predict_y):
print("before_loss")


+ 1
- 1
fastNLP/core/trainer.py View File

@@ -257,7 +257,7 @@ class Trainer(object):

self._update()
# lr scheduler; lr_finder; one_cycle
self.callback_manager.after_step()
self.callback_manager.after_step(self.optimizer)

self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():


+ 1
- 1
test/core/test_dataset.py View File

@@ -197,4 +197,4 @@ class TestDataSetIter(unittest.TestCase):
def test__repr__(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
for iter in ds:
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}")

+ 2
- 1
test/core/test_metrics.py View File

@@ -360,7 +360,8 @@ class TestBMESF1PreRecMetric(unittest.TestCase):

metric = BMESF1PreRecMetric()
metric(pred_dict, target_dict)
self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0})
self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0})


class TestUsefulFunctions(unittest.TestCase):
# 测试metrics.py中一些看上去挺有用的函数


Loading…
Cancel
Save