@@ -11,25 +11,14 @@
},
{
"cell_type": "code",
"execution_count": 1 ,
"execution_count": 2 ,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
]
}
],
"outputs": [],
"source": [
"from fastNLP.io import SST2Pipe\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"from fastNLP.models import CNNText\n",
"from fastNLP import CrossEntropyLoss\n",
"import torch\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"databundle = SST2Pipe().process_from_file()\n",
"vocab = databundle.get_vocab('words')\n",
@@ -40,7 +29,6 @@
"model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
"loss = CrossEntropyLoss()\n",
"metric = AccuracyMetric()\n",
"optimizer = Adam(model.parameters(), lr=0.001)\n",
"device = 0 if torch.cuda.is_available() else 'cpu'"
]
},
@@ -53,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 2 ,
"execution_count": 3 ,
"metadata": {
"scrolled": true
},
@@ -63,12 +51,12 @@
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13 ]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4 ]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-11-51 \n"
"training epochs started 2020-02-28-00-37-08 \n"
]
},
{
@@ -104,11 +92,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.28 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccuracyMetric: acc=0.722 477\n",
"AccuracyMetric: acc=0.747706 \n",
"\n"
]
},
@@ -131,11 +119,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.36 seconds!\n",
"Evaluate data in 0.17 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccuracyMetric: acc=0.762615 \n",
"AccuracyMetric: acc=0.745413 \n",
"\n"
]
},
@@ -158,11 +146,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.19 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccuracyMetric: acc=0.771789 \n",
"AccuracyMetric: acc=0.74656 \n",
"\n"
]
},
@@ -185,11 +173,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"Evaluate data in 0.15 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccuracyMetric: acc=0.759174 \n",
"AccuracyMetric: acc=0.762615 \n",
"\n"
]
},
@@ -212,11 +200,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.4 2 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccuracyMetric: acc=0.75344 \n",
"AccuracyMetric: acc=0.736239 \n",
"\n"
]
},
@@ -239,11 +227,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.33 seconds!\n",
"Evaluate data in 0.16 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccuracyMetric: acc=0.75 \n",
"AccuracyMetric: acc=0.761468 \n",
"\n"
]
},
@@ -266,11 +254,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccuracyMetric: acc=0.74197 2\n",
"AccuracyMetric: acc=0.727064 \n",
"\n"
]
},
@@ -293,11 +281,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.49 seconds!\n",
"Evaluate data in 0.21 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccuracyMetric: acc=0.740826 \n",
"AccuracyMetric: acc=0.731651 \n",
"\n"
]
},
@@ -320,11 +308,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.1 5 seconds!\n",
"Evaluate data in 0.52 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccuracyMetric: acc=0.75\n",
"AccuracyMetric: acc=0.752294 \n",
"\n"
]
},
@@ -347,36 +335,35 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccuracyMetric: acc=0.752294 \n",
"AccuracyMetric: acc=0.760321 \n",
"\n",
"\r\n",
"In Epoch:3/Step:462 , got best dev performance:\n",
"AccuracyMetric: acc=0.771789 \n",
"In Epoch:4/Step:616 , got best dev performance:\n",
"AccuracyMetric: acc=0.762615 \n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.771789 }},\n",
" 'best_epoch': 3 ,\n",
" 'best_step': 462 ,\n",
" 'seconds': 30.04 }"
"{'best_eval': {'AccuracyMetric': {'acc': 0.762615 }},\n",
" 'best_epoch': 4 ,\n",
" 'best_step': 616 ,\n",
" 'seconds': 32.63 }"
]
},
"execution_count": 2 ,
"execution_count": 3 ,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=metric, device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=metric)\n",
"trainer.train()"
]
},
@@ -432,7 +419,7 @@
},
{
"cell_type": "code",
"execution_count": 3 ,
"execution_count": 4 ,
"metadata": {},
"outputs": [],
"source": [
@@ -464,7 +451,7 @@
},
{
"cell_type": "code",
"execution_count": 4 ,
"execution_count": 5 ,
"metadata": {
"scrolled": true
},
@@ -474,12 +461,12 @@
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13 ]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4 ]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-12-2 1\n"
"training epochs started 2020-02-28-00-37-4 1\n"
]
},
{
@@ -515,11 +502,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.33 seconds!\n",
"Evaluate data in 0.27 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccMetric: acc=0.7419724770642202 \n",
"AccMetric: acc=0.7431192660550459 \n",
"\n"
]
},
@@ -542,11 +529,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccMetric: acc=0.7660550458715596 \n",
"AccMetric: acc=0.7522935779816514 \n",
"\n"
]
},
@@ -569,11 +556,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.27 seconds!\n",
"Evaluate data in 0.51 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccMetric: acc=0.75 \n",
"AccMetric: acc=0.7477064220183486 \n",
"\n"
]
},
@@ -596,11 +583,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.2 4 seconds!\n",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccMetric: acc=0.7534403669724771 \n",
"AccMetric: acc=0.7442660550458715 \n",
"\n"
]
},
@@ -623,11 +610,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.5 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.7362385321100917 \n",
"\n"
]
},
@@ -650,11 +637,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.1 4 seconds!\n",
"Evaluate data in 0.45 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.7293577981651376 \n",
"\n"
]
},
@@ -677,11 +664,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.27 seconds!\n",
"Evaluate data in 0.33 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccMetric: acc=0.7568807339449541 \n",
"AccMetric: acc=0.7190366972477065 \n",
"\n"
]
},
@@ -704,11 +691,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.4 2 seconds!\n",
"Evaluate data in 0.29 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.7419724770642202 \n",
"\n"
]
},
@@ -731,11 +718,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"Evaluate data in 0.34 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccMetric: acc=0.7408256880733946 \n",
"AccMetric: acc=0.7350917431192661 \n",
"\n"
]
},
@@ -758,36 +745,35 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.2 8 seconds!\n",
"Evaluate data in 0.1 8 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccMetric: acc=0.7408256880733946 \n",
"AccMetric: acc=0.6846330275229358 \n",
"\n",
"\r\n",
"In Epoch:2/Step:308, got best dev performance:\n",
"AccMetric: acc=0.7660550458715596 \n",
"AccMetric: acc=0.7522935779816514 \n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7660550458715596 }},\n",
"{'best_eval': {'AccMetric': {'acc': 0.7522935779816514 }},\n",
" 'best_epoch': 2,\n",
" 'best_step': 308,\n",
" 'seconds': 29 .74 }"
" 'seconds': 4 2.7}"
]
},
"execution_count": 4 ,
"execution_count": 5 ,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=AccMetric(), device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
]
},
@@ -802,7 +788,7 @@
},
{
"cell_type": "code",
"execution_count": 5 ,
"execution_count": 6 ,
"metadata": {},
"outputs": [],
"source": [
@@ -841,7 +827,7 @@
},
{
"cell_type": "code",
"execution_count": 6 ,
"execution_count": 7 ,
"metadata": {
"scrolled": true
},
@@ -851,12 +837,12 @@
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13 ]) \n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4 ]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-12-51 \n"
"training epochs started 2020-02-28-00-38-24 \n"
]
},
{
@@ -892,11 +878,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.3 2 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccMetric: acc=0.754587155963302 7\n",
"AccMetric: acc=0.751146788990825 7\n",
"\n"
]
},
@@ -919,11 +905,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.24 seconds!\n",
"Evaluate data in 0.29 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccMetric: acc=0.7534403669724771 \n",
"AccMetric: acc=0.7454128440366973 \n",
"\n"
]
},
@@ -946,11 +932,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.18 seconds!\n",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccMetric: acc=0.755733944954128 5\n",
"AccMetric: acc=0.722477064220183 5\n",
"\n"
]
},
@@ -973,11 +959,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.11 seconds!\n",
"Evaluate data in 0.4 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccMetric: acc=0.7511467889908257 \n",
"AccMetric: acc=0.7534403669724771 \n",
"\n"
]
},
@@ -1000,11 +986,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"Evaluate data in 0.4 1 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccMetric: acc=0.7465596330275229 \n",
"AccMetric: acc=0.7396788990825688 \n",
"\n"
]
},
@@ -1027,11 +1013,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.14 seconds!\n",
"Evaluate data in 0.22 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccMetric: acc=0.7454128440366973 \n",
"AccMetric: acc=0.7442660550458715 \n",
"\n"
]
},
@@ -1054,11 +1040,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.43 seconds!\n",
"Evaluate data in 0.45 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccMetric: acc=0.7488532110091743 \n",
"AccMetric: acc=0.6903669724770642 \n",
"\n"
]
},
@@ -1081,11 +1067,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.21 seconds!\n",
"Evaluate data in 0.25 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccMetric: acc=0.7431192660550459 \n",
"AccMetric: acc=0.7293577981651376 \n",
"\n"
]
},
@@ -1108,11 +1094,11 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.1 seconds!\n",
"Evaluate data in 0.4 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccMetric: acc=0.7477064220183486 \n",
"AccMetric: acc=0.7006880733944955 \n",
"\n"
]
},
@@ -1135,40 +1121,60 @@
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccMetric: acc=0.7465596330275229 \n",
"AccMetric: acc=0.7339449541284404 \n",
"\n",
"\r\n",
"In Epoch:3/Step:462 , got best dev performance:\n",
"AccMetric: acc=0.7557339449541285 \n",
"In Epoch:4/Step:616 , got best dev performance:\n",
"AccMetric: acc=0.7534403669724771 \n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7557339449541285 }},\n",
" 'best_epoch': 3 ,\n",
" 'best_step': 462 ,\n",
" 'seconds': 28.68 }"
"{'best_eval': {'AccMetric': {'acc': 0.7534403669724771 }},\n",
" 'best_epoch': 4 ,\n",
" 'best_step': 616 ,\n",
" 'seconds': 34.74 }"
]
},
"execution_count": 6 ,
"execution_count": 7 ,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
" optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
" metrics=AccMetric(pred=\"pred\", target=\"target\"), device=device)\n",
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.\n",
"``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.\n",
"``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.\n",
"\n",
"``MetricBase`` 会进行以下的类型检测:\n",
"\n",
"1. self.evaluate当中是否有 varargs, 这是不支持的.\n",
"2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .\n",
"3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .\n",
"\n",
"除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数\n",
"如果kwargs是self.evaluate的参数,则不会检测\n",
"\n",
"self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值\n",
"self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值\n"
]
},
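{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an illustrative sketch, not taken from the original notebook, of a custom metric that follows the contract described above: it subclasses ``MetricBase``, maps the argument names of ``evaluate`` with ``_init_param_map``, accumulates statistics per batch in ``evaluate``, and reports them from ``get_metric``. The class name ``SketchAccMetric`` is made up for this example.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from fastNLP import MetricBase\n",
"\n",
"class SketchAccMetric(MetricBase):\n",
"    def __init__(self, pred=None, target=None):\n",
"        super().__init__()\n",
"        # Map the argument names of evaluate() to field names in pred_dict / target_dict.\n",
"        self._init_param_map(pred=pred, target=target)\n",
"        self.total = 0\n",
"        self.acc_count = 0\n",
"\n",
"    def evaluate(self, pred, target):\n",
"        # Called once per batch: accumulate statistics, return nothing.\n",
"        self.total += target.size(0)\n",
"        self.acc_count += torch.sum(torch.eq(pred.argmax(dim=-1), target)).item()\n",
"\n",
"    def get_metric(self, reset=True):\n",
"        # Aggregate the accumulated statistics into a dict of metric name -> value.\n",
"        acc = self.acc_count / self.total if self.total > 0 else 0.0\n",
"        if reset:\n",
"            self.total = 0\n",
"            self.acc_count = 0\n",
"        return {'acc': acc}"
]
},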
{
"cell_type": "code",
"execution_count": null,
"metadata": {},