Browse Source

简化了 tutorial_7 的代码

tags/v0.5.5
ChenXin 5 years ago
parent
commit
9cb7cdb532
2 changed files with 121 additions and 116 deletions
  1. +3
    -4
      docs/source/tutorials/tutorial_7_metrics.rst
  2. +118
    -112
      tutorials/tutorial_7_metrics.ipynb

+ 3
- 4
docs/source/tutorials/tutorial_7_metrics.rst View File

@@ -7,9 +7,8 @@

.. code-block:: python

trainer = Trainer(train_data=train_data, model=model, loss=loss,
optimizer=optimizer, batch_size=32, dev_data=dev_data,
metrics=metric, device=device)
trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,
loss=loss, device=device, metrics=metric)
trainer.train()

除了 :class:`~fastNLP.AccuracyMetric` 之外,:class:`~fastNLP.SpanFPreRecMetric` 也是一种非常见的评价指标,
@@ -89,7 +88,7 @@
super().__init__()

# 如果没有注册该则效果与 Version 1 就是一样的
self._init_param_map(pred=pred, target=target) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可
self._init_param_map(pred=pred, target=target) # 该方法会注册 pred 和 target . 仅需要注册evaluate()方法会用到的参数名即可

# 根据你的情况自定义指标
self.total = 0


+ 118
- 112
tutorials/tutorial_7_metrics.ipynb View File

@@ -11,25 +11,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
]
}
],
"outputs": [],
"source": [
"from fastNLP.io import SST2Pipe\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"from fastNLP.models import CNNText\n",
"from fastNLP import CrossEntropyLoss\n",
"import torch\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"databundle = SST2Pipe().process_from_file()\n",
"vocab = databundle.get_vocab('words')\n",
@@ -40,7 +29,6 @@
"model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
"loss = CrossEntropyLoss()\n",
"metric = AccuracyMetric()\n",
"optimizer = Adam(model.parameters(), lr=0.001)\n",
"device = 0 if torch.cuda.is_available() else 'cpu'"
]
},
@@ -53,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"scrolled": true
},
@@ -63,12 +51,12 @@
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-11-51\n"
"training epochs started 2020-02-28-00-37-08\n"
]
},
{
@@ -104,11 +92,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.28 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccuracyMetric: acc=0.722477\n",
"AccuracyMetric: acc=0.747706\n",
"\n"
]
},
@@ -131,11 +119,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.36 seconds!\n",
"Evaluate data in 0.17 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccuracyMetric: acc=0.762615\n",
"AccuracyMetric: acc=0.745413\n",
"\n"
]
},
@@ -158,11 +146,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.19 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccuracyMetric: acc=0.771789\n",
"AccuracyMetric: acc=0.74656\n",
"\n"
]
},
@@ -185,11 +173,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"Evaluate data in 0.15 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccuracyMetric: acc=0.759174\n",
"AccuracyMetric: acc=0.762615\n",
"\n"
]
},
@@ -212,11 +200,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccuracyMetric: acc=0.75344\n",
"AccuracyMetric: acc=0.736239\n",
"\n"
]
},
@@ -239,11 +227,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.33 seconds!\n",
"Evaluate data in 0.16 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccuracyMetric: acc=0.75\n",
"AccuracyMetric: acc=0.761468\n",
"\n"
]
},
@@ -266,11 +254,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccuracyMetric: acc=0.741972\n",
"AccuracyMetric: acc=0.727064\n",
"\n"
]
},
@@ -293,11 +281,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.49 seconds!\n",
"Evaluate data in 0.21 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccuracyMetric: acc=0.740826\n",
"AccuracyMetric: acc=0.731651\n",
"\n"
]
},
@@ -320,11 +308,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.15 seconds!\n",
"Evaluate data in 0.52 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccuracyMetric: acc=0.75\n",
"AccuracyMetric: acc=0.752294\n",
"\n"
]
},
@@ -347,36 +335,35 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccuracyMetric: acc=0.752294\n",
"AccuracyMetric: acc=0.760321\n",
"\n",
"\r\n",
"In Epoch:3/Step:462, got best dev performance:\n",
"AccuracyMetric: acc=0.771789\n",
"In Epoch:4/Step:616, got best dev performance:\n",
"AccuracyMetric: acc=0.762615\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.771789}},\n",
" 'best_epoch': 3,\n",
" 'best_step': 462,\n",
" 'seconds': 30.04}"
"{'best_eval': {'AccuracyMetric': {'acc': 0.762615}},\n",
" 'best_epoch': 4,\n",
" 'best_step': 616,\n",
" 'seconds': 32.63}"
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=metric, device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=metric)\n",
"trainer.train()"
]
},
@@ -432,7 +419,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -464,7 +451,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"scrolled": true
},
@@ -474,12 +461,12 @@
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-12-21\n"
"training epochs started 2020-02-28-00-37-41\n"
]
},
{
@@ -515,11 +502,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.33 seconds!\n",
"Evaluate data in 0.27 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccMetric: acc=0.7419724770642202\n",
"AccMetric: acc=0.7431192660550459\n",
"\n"
]
},
@@ -542,11 +529,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccMetric: acc=0.7660550458715596\n",
"AccMetric: acc=0.7522935779816514\n",
"\n"
]
},
@@ -569,11 +556,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.27 seconds!\n",
"Evaluate data in 0.51 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccMetric: acc=0.75\n",
"AccMetric: acc=0.7477064220183486\n",
"\n"
]
},
@@ -596,11 +583,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccMetric: acc=0.7534403669724771\n",
"AccMetric: acc=0.7442660550458715\n",
"\n"
]
},
@@ -623,11 +610,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.5 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743\n",
"AccMetric: acc=0.7362385321100917\n",
"\n"
]
},
@@ -650,11 +637,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.14 seconds!\n",
"Evaluate data in 0.45 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743\n",
"AccMetric: acc=0.7293577981651376\n",
"\n"
]
},
@@ -677,11 +664,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.27 seconds!\n",
"Evaluate data in 0.33 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccMetric: acc=0.7568807339449541\n",
"AccMetric: acc=0.7190366972477065\n",
"\n"
]
},
@@ -704,11 +691,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.42 seconds!\n",
"Evaluate data in 0.29 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743\n",
"AccMetric: acc=0.7419724770642202\n",
"\n"
]
},
@@ -731,11 +718,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.34 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccMetric: acc=0.7408256880733946\n",
"AccMetric: acc=0.7350917431192661\n",
"\n"
]
},
@@ -758,36 +745,35 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.28 seconds!\n",
"Evaluate data in 0.18 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccMetric: acc=0.7408256880733946\n",
"AccMetric: acc=0.6846330275229358\n",
"\n",
"\r\n",
"In Epoch:2/Step:308, got best dev performance:\n",
"AccMetric: acc=0.7660550458715596\n",
"AccMetric: acc=0.7522935779816514\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7660550458715596}},\n",
"{'best_eval': {'AccMetric': {'acc': 0.7522935779816514}},\n",
" 'best_epoch': 2,\n",
" 'best_step': 308,\n",
" 'seconds': 29.74}"
" 'seconds': 42.7}"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=AccMetric(), device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
]
},
@@ -802,7 +788,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -841,7 +827,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {
"scrolled": true
},
@@ -851,12 +837,12 @@
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-12-51\n"
"training epochs started 2020-02-28-00-38-24\n"
]
},
{
@@ -892,11 +878,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.32 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccMetric: acc=0.7545871559633027\n",
"AccMetric: acc=0.7511467889908257\n",
"\n"
]
},
@@ -919,11 +905,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.29 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccMetric: acc=0.7534403669724771\n",
"AccMetric: acc=0.7454128440366973\n",
"\n"
]
},
@@ -946,11 +932,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.18 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccMetric: acc=0.7557339449541285\n",
"AccMetric: acc=0.7224770642201835\n",
"\n"
]
},
@@ -973,11 +959,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.11 seconds!\n",
"Evaluate data in 0.4 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccMetric: acc=0.7511467889908257\n",
"AccMetric: acc=0.7534403669724771\n",
"\n"
]
},
@@ -1000,11 +986,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.41 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccMetric: acc=0.7465596330275229\n",
"AccMetric: acc=0.7396788990825688\n",
"\n"
]
},
@@ -1027,11 +1013,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.14 seconds!\n",
"Evaluate data in 0.22 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccMetric: acc=0.7454128440366973\n",
"AccMetric: acc=0.7442660550458715\n",
"\n"
]
},
@@ -1054,11 +1040,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.43 seconds!\n",
"Evaluate data in 0.45 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743\n",
"AccMetric: acc=0.6903669724770642\n",
"\n"
]
},
@@ -1081,11 +1067,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.21 seconds!\n",
"Evaluate data in 0.25 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccMetric: acc=0.7431192660550459\n",
"AccMetric: acc=0.7293577981651376\n",
"\n"
]
},
@@ -1108,11 +1094,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.1 seconds!\n",
"Evaluate data in 0.4 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccMetric: acc=0.7477064220183486\n",
"AccMetric: acc=0.7006880733944955\n",
"\n"
]
},
@@ -1135,40 +1121,60 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccMetric: acc=0.7465596330275229\n",
"AccMetric: acc=0.7339449541284404\n",
"\n",
"\r\n",
"In Epoch:3/Step:462, got best dev performance:\n",
"AccMetric: acc=0.7557339449541285\n",
"In Epoch:4/Step:616, got best dev performance:\n",
"AccMetric: acc=0.7534403669724771\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7557339449541285}},\n",
" 'best_epoch': 3,\n",
" 'best_step': 462,\n",
" 'seconds': 28.68}"
"{'best_eval': {'AccMetric': {'acc': 0.7534403669724771}},\n",
" 'best_epoch': 4,\n",
" 'best_step': 616,\n",
" 'seconds': 34.74}"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=AccMetric(pred=\"pred\", target=\"target\"), device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.\n",
"``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.\n",
"``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.\n",
"\n",
"``MetricBase`` 会进行以下的类型检测:\n",
"\n",
"1. self.evaluate当中是否有 varargs, 这是不支持的.\n",
"2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .\n",
"3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .\n",
"\n",
"除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数\n",
"如果kwargs是self.evaluate的参数,则不会检测\n",
"\n",
"self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值\n",
"self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},


Loading…
Cancel
Save