From 9cb7cdb532a02e50948391c2322a9d0f0aa4d7ff Mon Sep 17 00:00:00 2001 From: ChenXin Date: Fri, 28 Feb 2020 00:44:15 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8C=96=E4=BA=86=20tutorial=5F7=20?= =?UTF-8?q?=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/tutorials/tutorial_7_metrics.rst | 7 +- tutorials/tutorial_7_metrics.ipynb | 230 ++++++++++++++------------- 2 files changed, 121 insertions(+), 116 deletions(-) diff --git a/docs/source/tutorials/tutorial_7_metrics.rst b/docs/source/tutorials/tutorial_7_metrics.rst index bb17292c..2731c023 100644 --- a/docs/source/tutorials/tutorial_7_metrics.rst +++ b/docs/source/tutorials/tutorial_7_metrics.rst @@ -7,9 +7,8 @@ .. code-block:: python - trainer = Trainer(train_data=train_data, model=model, loss=loss, - optimizer=optimizer, batch_size=32, dev_data=dev_data, - metrics=metric, device=device) + trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model, + loss=loss, device=device, metrics=metric) trainer.train() 除了 :class:`~fastNLP.AccuracyMetric` 之外,:class:`~fastNLP.SpanFPreRecMetric` 也是一种非常见的评价指标, @@ -89,7 +88,7 @@ super().__init__() # 如果没有注册该则效果与 Version 1 就是一样的 - self._init_param_map(pred=pred, target=target) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 + self._init_param_map(pred=pred, target=target) # 该方法会注册 pred 和 target . 仅需要注册evaluate()方法会用到的参数名即可 # 根据你的情况自定义指标 self.total = 0 diff --git a/tutorials/tutorial_7_metrics.ipynb b/tutorials/tutorial_7_metrics.ipynb index e6780587..ef791683 100644 --- a/tutorials/tutorial_7_metrics.ipynb +++ b/tutorials/tutorial_7_metrics.ipynb @@ -11,25 +11,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n" - ] - } - ], + "outputs": [], "source": [ "from fastNLP.io import SST2Pipe\n", "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n", "from fastNLP.models import CNNText\n", - "from fastNLP import CrossEntropyLoss\n", "import torch\n", - "from torch.optim import Adam\n", - "from fastNLP import AccuracyMetric\n", "\n", "databundle = SST2Pipe().process_from_file()\n", "vocab = databundle.get_vocab('words')\n", @@ -40,7 +29,6 @@ "model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n", "loss = CrossEntropyLoss()\n", "metric = AccuracyMetric()\n", - "optimizer = Adam(model.parameters(), lr=0.001)\n", "device = 0 if torch.cuda.is_available() else 'cpu'" ] }, @@ -53,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "scrolled": true }, @@ -63,12 +51,12 @@ "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", - "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n", + "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", - "training epochs started 2020-02-28-00-11-51\n" + "training epochs started 2020-02-28-00-37-08\n" ] }, { @@ -104,11 +92,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.16 seconds!\n", + "Evaluate data in 0.28 seconds!\n", "\r", "Evaluation on dev at Epoch 1/10. Step:154/1540: \n", "\r", - "AccuracyMetric: acc=0.722477\n", + "AccuracyMetric: acc=0.747706\n", "\n" ] }, @@ -131,11 +119,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.36 seconds!\n", + "Evaluate data in 0.17 seconds!\n", "\r", "Evaluation on dev at Epoch 2/10. Step:308/1540: \n", "\r", - "AccuracyMetric: acc=0.762615\n", + "AccuracyMetric: acc=0.745413\n", "\n" ] }, @@ -158,11 +146,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.16 seconds!\n", + "Evaluate data in 0.19 seconds!\n", "\r", "Evaluation on dev at Epoch 3/10. Step:462/1540: \n", "\r", - "AccuracyMetric: acc=0.771789\n", + "AccuracyMetric: acc=0.74656\n", "\n" ] }, @@ -185,11 +173,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.44 seconds!\n", + "Evaluate data in 0.15 seconds!\n", "\r", "Evaluation on dev at Epoch 4/10. Step:616/1540: \n", "\r", - "AccuracyMetric: acc=0.759174\n", + "AccuracyMetric: acc=0.762615\n", "\n" ] }, @@ -212,11 +200,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.29 seconds!\n", + "Evaluate data in 0.42 seconds!\n", "\r", "Evaluation on dev at Epoch 5/10. Step:770/1540: \n", "\r", - "AccuracyMetric: acc=0.75344\n", + "AccuracyMetric: acc=0.736239\n", "\n" ] }, @@ -239,11 +227,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.33 seconds!\n", + "Evaluate data in 0.16 seconds!\n", "\r", "Evaluation on dev at Epoch 6/10. Step:924/1540: \n", "\r", - "AccuracyMetric: acc=0.75\n", + "AccuracyMetric: acc=0.761468\n", "\n" ] }, @@ -266,11 +254,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.19 seconds!\n", + "Evaluate data in 0.42 seconds!\n", "\r", "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n", "\r", - "AccuracyMetric: acc=0.741972\n", + "AccuracyMetric: acc=0.727064\n", "\n" ] }, @@ -293,11 +281,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.49 seconds!\n", + "Evaluate data in 0.21 seconds!\n", "\r", "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n", "\r", - "AccuracyMetric: acc=0.740826\n", + "AccuracyMetric: acc=0.731651\n", "\n" ] }, @@ -320,11 +308,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.15 seconds!\n", + "Evaluate data in 0.52 seconds!\n", "\r", "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n", "\r", - "AccuracyMetric: acc=0.75\n", + "AccuracyMetric: acc=0.752294\n", "\n" ] }, @@ -347,36 +335,35 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.16 seconds!\n", + "Evaluate data in 0.44 seconds!\n", "\r", "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n", "\r", - "AccuracyMetric: acc=0.752294\n", + "AccuracyMetric: acc=0.760321\n", "\n", "\r\n", - "In Epoch:3/Step:462, got best dev performance:\n", - "AccuracyMetric: acc=0.771789\n", + "In Epoch:4/Step:616, got best dev performance:\n", + "AccuracyMetric: acc=0.762615\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ - "{'best_eval': {'AccuracyMetric': {'acc': 0.771789}},\n", - " 'best_epoch': 3,\n", - " 'best_step': 462,\n", - " 'seconds': 30.04}" + "{'best_eval': {'AccuracyMetric': {'acc': 0.762615}},\n", + " 'best_epoch': 4,\n", + " 'best_step': 616,\n", + " 'seconds': 32.63}" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "trainer = Trainer(train_data=train_data, model=model, loss=loss,\n", - " optimizer=optimizer, batch_size=32, dev_data=dev_data,\n", - " metrics=metric, device=device)\n", + "trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n", + " loss=loss, device=device, metrics=metric)\n", "trainer.train()" ] }, @@ -432,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -464,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "scrolled": true }, @@ -474,12 +461,12 @@ "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", - "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n", + "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", - "training epochs started 2020-02-28-00-12-21\n" + "training epochs started 2020-02-28-00-37-41\n" ] }, { @@ -515,11 +502,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.33 seconds!\n", + "Evaluate data in 0.27 seconds!\n", "\r", "Evaluation on dev at Epoch 1/10. Step:154/1540: \n", "\r", - "AccMetric: acc=0.7419724770642202\n", + "AccMetric: acc=0.7431192660550459\n", "\n" ] }, @@ -542,11 +529,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.19 seconds!\n", + "Evaluate data in 0.42 seconds!\n", "\r", "Evaluation on dev at Epoch 2/10. Step:308/1540: \n", "\r", - "AccMetric: acc=0.7660550458715596\n", + "AccMetric: acc=0.7522935779816514\n", "\n" ] }, @@ -569,11 +556,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.27 seconds!\n", + "Evaluate data in 0.51 seconds!\n", "\r", "Evaluation on dev at Epoch 3/10. Step:462/1540: \n", "\r", - "AccMetric: acc=0.75\n", + "AccMetric: acc=0.7477064220183486\n", "\n" ] }, @@ -596,11 +583,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.24 seconds!\n", + "Evaluate data in 0.48 seconds!\n", "\r", "Evaluation on dev at Epoch 4/10. Step:616/1540: \n", "\r", - "AccMetric: acc=0.7534403669724771\n", + "AccMetric: acc=0.7442660550458715\n", "\n" ] }, @@ -623,11 +610,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.29 seconds!\n", + "Evaluate data in 0.5 seconds!\n", "\r", "Evaluation on dev at Epoch 5/10. Step:770/1540: \n", "\r", - "AccMetric: acc=0.7488532110091743\n", + "AccMetric: acc=0.7362385321100917\n", "\n" ] }, @@ -650,11 +637,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.14 seconds!\n", + "Evaluate data in 0.45 seconds!\n", "\r", "Evaluation on dev at Epoch 6/10. Step:924/1540: \n", "\r", - "AccMetric: acc=0.7488532110091743\n", + "AccMetric: acc=0.7293577981651376\n", "\n" ] }, @@ -677,11 +664,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.27 seconds!\n", + "Evaluate data in 0.33 seconds!\n", "\r", "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n", "\r", - "AccMetric: acc=0.7568807339449541\n", + "AccMetric: acc=0.7190366972477065\n", "\n" ] }, @@ -704,11 +691,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.42 seconds!\n", + "Evaluate data in 0.29 seconds!\n", "\r", "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n", "\r", - "AccMetric: acc=0.7488532110091743\n", + "AccMetric: acc=0.7419724770642202\n", "\n" ] }, @@ -731,11 +718,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.16 seconds!\n", + "Evaluate data in 0.34 seconds!\n", "\r", "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n", "\r", - "AccMetric: acc=0.7408256880733946\n", + "AccMetric: acc=0.7350917431192661\n", "\n" ] }, @@ -758,36 +745,35 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.28 seconds!\n", + "Evaluate data in 0.18 seconds!\n", "\r", "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n", "\r", - "AccMetric: acc=0.7408256880733946\n", + "AccMetric: acc=0.6846330275229358\n", "\n", "\r\n", "In Epoch:2/Step:308, got best dev performance:\n", - "AccMetric: acc=0.7660550458715596\n", + "AccMetric: acc=0.7522935779816514\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ - "{'best_eval': {'AccMetric': {'acc': 0.7660550458715596}},\n", + "{'best_eval': {'AccMetric': {'acc': 0.7522935779816514}},\n", " 'best_epoch': 2,\n", " 'best_step': 308,\n", - " 'seconds': 29.74}" + " 'seconds': 42.7}" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "trainer = Trainer(train_data=train_data, model=model, loss=loss,\n", - " optimizer=optimizer, batch_size=32, dev_data=dev_data,\n", - " metrics=AccMetric(), device=device)\n", + "trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n", + " loss=loss, device=device, metrics=AccMetric())\n", "trainer.train()" ] }, @@ -802,7 +788,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -841,7 +827,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "scrolled": true }, @@ -851,12 +837,12 @@ "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", - "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n", + "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", - "training epochs started 2020-02-28-00-12-51\n" + "training epochs started 2020-02-28-00-38-24\n" ] }, { @@ -892,11 +878,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.24 seconds!\n", + "Evaluate data in 0.32 seconds!\n", "\r", "Evaluation on dev at Epoch 1/10. Step:154/1540: \n", "\r", - "AccMetric: acc=0.7545871559633027\n", + "AccMetric: acc=0.7511467889908257\n", "\n" ] }, @@ -919,11 +905,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.24 seconds!\n", + "Evaluate data in 0.29 seconds!\n", "\r", "Evaluation on dev at Epoch 2/10. Step:308/1540: \n", "\r", - "AccMetric: acc=0.7534403669724771\n", + "AccMetric: acc=0.7454128440366973\n", "\n" ] }, @@ -946,11 +932,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.18 seconds!\n", + "Evaluate data in 0.42 seconds!\n", "\r", "Evaluation on dev at Epoch 3/10. Step:462/1540: \n", "\r", - "AccMetric: acc=0.7557339449541285\n", + "AccMetric: acc=0.7224770642201835\n", "\n" ] }, @@ -973,11 +959,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.11 seconds!\n", + "Evaluate data in 0.4 seconds!\n", "\r", "Evaluation on dev at Epoch 4/10. Step:616/1540: \n", "\r", - "AccMetric: acc=0.7511467889908257\n", + "AccMetric: acc=0.7534403669724771\n", "\n" ] }, @@ -1000,11 +986,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.19 seconds!\n", + "Evaluate data in 0.41 seconds!\n", "\r", "Evaluation on dev at Epoch 5/10. Step:770/1540: \n", "\r", - "AccMetric: acc=0.7465596330275229\n", + "AccMetric: acc=0.7396788990825688\n", "\n" ] }, @@ -1027,11 +1013,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.14 seconds!\n", + "Evaluate data in 0.22 seconds!\n", "\r", "Evaluation on dev at Epoch 6/10. Step:924/1540: \n", "\r", - "AccMetric: acc=0.7454128440366973\n", + "AccMetric: acc=0.7442660550458715\n", "\n" ] }, @@ -1054,11 +1040,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.43 seconds!\n", + "Evaluate data in 0.45 seconds!\n", "\r", "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n", "\r", - "AccMetric: acc=0.7488532110091743\n", + "AccMetric: acc=0.6903669724770642\n", "\n" ] }, @@ -1081,11 +1067,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.21 seconds!\n", + "Evaluate data in 0.25 seconds!\n", "\r", "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n", "\r", - "AccMetric: acc=0.7431192660550459\n", + "AccMetric: acc=0.7293577981651376\n", "\n" ] }, @@ -1108,11 +1094,11 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.1 seconds!\n", + "Evaluate data in 0.4 seconds!\n", "\r", "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n", "\r", - "AccMetric: acc=0.7477064220183486\n", + "AccMetric: acc=0.7006880733944955\n", "\n" ] }, @@ -1135,40 +1121,60 @@ "output_type": "stream", "text": [ "\r", - "Evaluate data in 0.29 seconds!\n", + "Evaluate data in 0.48 seconds!\n", "\r", "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n", "\r", - "AccMetric: acc=0.7465596330275229\n", + "AccMetric: acc=0.7339449541284404\n", "\n", "\r\n", - "In Epoch:3/Step:462, got best dev performance:\n", - "AccMetric: acc=0.7557339449541285\n", + "In Epoch:4/Step:616, got best dev performance:\n", + "AccMetric: acc=0.7534403669724771\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ - "{'best_eval': {'AccMetric': {'acc': 0.7557339449541285}},\n", - " 'best_epoch': 3,\n", - " 'best_step': 462,\n", - " 'seconds': 28.68}" + "{'best_eval': {'AccMetric': {'acc': 0.7534403669724771}},\n", + " 'best_epoch': 4,\n", + " 'best_step': 616,\n", + " 'seconds': 34.74}" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "trainer = Trainer(train_data=train_data, model=model, loss=loss,\n", - " optimizer=optimizer, batch_size=32, dev_data=dev_data,\n", - " metrics=AccMetric(pred=\"pred\", target=\"target\"), device=device)\n", + "trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n", + " loss=loss, device=device, metrics=AccMetric())\n", "trainer.train()" ] }, { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.\n", + "``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.\n", + "``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.\n", + "\n", + "``MetricBase`` 会进行以下的类型检测:\n", + "\n", + "1. self.evaluate当中是否有 varargs, 这是不支持的.\n", + "2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .\n", + "3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .\n", + "\n", + "除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数\n", + "如果kwargs是self.evaluate的参数,则不会检测\n", + "\n", + "self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值\n", + "self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值\n" + ] + }, + { "cell_type": "code", "execution_count": null, "metadata": {},