|
|
@@ -42,36 +42,32 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 5, |
|
|
|
"execution_count": 11, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T12:56:06.772059Z", |
|
|
|
"start_time": "2017-12-22T12:56:06.766027Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import sys\n", |
|
|
|
"sys.path.append('..')\n", |
|
|
|
"\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch import nn\n", |
|
|
|
"import torch.nn.functional as F\n", |
|
|
|
"from torch.autograd import Variable\n", |
|
|
|
"from torchvision.datasets import CIFAR10" |
|
|
|
"from torchvision.datasets import CIFAR10\n", |
|
|
|
"from torchvision import transforms as tfs" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 6, |
|
|
|
"execution_count": 2, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T12:47:49.222432Z", |
|
|
|
"start_time": "2017-12-22T12:47:49.217940Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
@@ -82,19 +78,18 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 7, |
|
|
|
"execution_count": 3, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:14:02.429145Z", |
|
|
|
"start_time": "2017-12-22T13:14:02.383322Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"class Residual_Block(nn.Module):\n", |
|
|
|
" def __init__(self, in_channel, out_channel, same_shape=True):\n", |
|
|
|
" super(residual_block, self).__init__()\n", |
|
|
|
" super(Residual_Block, self).__init__()\n", |
|
|
|
" self.same_shape = same_shape\n", |
|
|
|
" stride=1 if self.same_shape else 2\n", |
|
|
|
" \n", |
|
|
@@ -127,7 +122,7 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 8, |
|
|
|
"execution_count": 4, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:14:05.793185Z", |
|
|
@@ -155,7 +150,7 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 9, |
|
|
|
"execution_count": 5, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:14:11.929120Z", |
|
|
@@ -201,13 +196,12 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 10, |
|
|
|
"execution_count": 6, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:27:46.099404Z", |
|
|
|
"start_time": "2017-12-22T13:27:45.986235Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
@@ -272,7 +266,7 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 11, |
|
|
|
"execution_count": 8, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:28:00.597030Z", |
|
|
@@ -302,39 +296,39 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 32, |
|
|
|
"execution_count": 12, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:29:01.484172Z", |
|
|
|
"start_time": "2017-12-22T13:29:00.095952Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from utils import train\n", |
|
|
|
"\n", |
|
|
|
"def data_tf(x):\n", |
|
|
|
" x = x.resize((96, 96), 2) # 将图片放大到 96 x 96\n", |
|
|
|
" x = np.array(x, dtype='float32') / 255\n", |
|
|
|
" x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到\n", |
|
|
|
" x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式\n", |
|
|
|
" x = torch.from_numpy(x)\n", |
|
|
|
" im_aug = tfs.Compose([\n", |
|
|
|
" tfs.Resize(96),\n", |
|
|
|
" tfs.ToTensor(),\n", |
|
|
|
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", |
|
|
|
" ])\n", |
|
|
|
" x = im_aug(x)\n", |
|
|
|
" return x\n", |
|
|
|
" \n", |
|
|
|
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n", |
|
|
|
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n", |
|
|
|
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", |
|
|
|
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", |
|
|
|
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", |
|
|
|
"\n", |
|
|
|
"net = ResNet(3, 10)\n", |
|
|
|
"optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n", |
|
|
|
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", |
|
|
|
"criterion = nn.CrossEntropyLoss()" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 33, |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-22T13:45:00.783186Z", |
|
|
@@ -346,31 +340,60 @@ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 0. Train Loss: 1.437317, Train Acc: 0.476662, Valid Loss: 1.928288, Valid Acc: 0.384691, Time 00:00:44\n", |
|
|
|
"Epoch 1. Train Loss: 0.992832, Train Acc: 0.648198, Valid Loss: 1.009847, Valid Acc: 0.642405, Time 00:00:48\n", |
|
|
|
"Epoch 2. Train Loss: 0.767309, Train Acc: 0.732617, Valid Loss: 1.827319, Valid Acc: 0.430380, Time 00:00:47\n", |
|
|
|
"Epoch 3. Train Loss: 0.606737, Train Acc: 0.788043, Valid Loss: 1.304808, Valid Acc: 0.585245, Time 00:00:46\n", |
|
|
|
"Epoch 4. Train Loss: 0.484436, Train Acc: 0.834499, Valid Loss: 1.335749, Valid Acc: 0.617089, Time 00:00:47\n", |
|
|
|
"Epoch 5. Train Loss: 0.374320, Train Acc: 0.872922, Valid Loss: 0.878519, Valid Acc: 0.724288, Time 00:00:47\n", |
|
|
|
"Epoch 6. Train Loss: 0.280981, Train Acc: 0.904212, Valid Loss: 0.931616, Valid Acc: 0.716871, Time 00:00:48\n", |
|
|
|
"Epoch 7. Train Loss: 0.210800, Train Acc: 0.929747, Valid Loss: 1.448870, Valid Acc: 0.638548, Time 00:00:48\n", |
|
|
|
"Epoch 8. Train Loss: 0.147873, Train Acc: 0.951427, Valid Loss: 1.356992, Valid Acc: 0.657536, Time 00:00:47\n", |
|
|
|
"Epoch 9. Train Loss: 0.112824, Train Acc: 0.963895, Valid Loss: 1.630560, Valid Acc: 0.627769, Time 00:00:47\n", |
|
|
|
"Epoch 10. Train Loss: 0.082685, Train Acc: 0.973905, Valid Loss: 0.982882, Valid Acc: 0.744264, Time 00:00:44\n", |
|
|
|
"Epoch 11. Train Loss: 0.065325, Train Acc: 0.979680, Valid Loss: 0.911631, Valid Acc: 0.767009, Time 00:00:47\n", |
|
|
|
"Epoch 12. Train Loss: 0.041401, Train Acc: 0.987952, Valid Loss: 1.167992, Valid Acc: 0.729826, Time 00:00:48\n", |
|
|
|
"Epoch 13. Train Loss: 0.037516, Train Acc: 0.989011, Valid Loss: 1.081807, Valid Acc: 0.746737, Time 00:00:47\n", |
|
|
|
"Epoch 14. Train Loss: 0.030674, Train Acc: 0.991468, Valid Loss: 0.935292, Valid Acc: 0.774031, Time 00:00:45\n", |
|
|
|
"Epoch 15. Train Loss: 0.021743, Train Acc: 0.994565, Valid Loss: 0.879348, Valid Acc: 0.790150, Time 00:00:47\n", |
|
|
|
"Epoch 16. Train Loss: 0.014642, Train Acc: 0.996463, Valid Loss: 1.328587, Valid Acc: 0.724387, Time 00:00:47\n", |
|
|
|
"Epoch 17. Train Loss: 0.011072, Train Acc: 0.997363, Valid Loss: 0.909065, Valid Acc: 0.792919, Time 00:00:47\n", |
|
|
|
"Epoch 18. Train Loss: 0.006870, Train Acc: 0.998561, Valid Loss: 0.923746, Valid Acc: 0.794403, Time 00:00:46\n", |
|
|
|
"Epoch 19. Train Loss: 0.004240, Train Acc: 0.999500, Valid Loss: 0.877908, Valid Acc: 0.802314, Time 00:00:46\n" |
|
|
|
"[ 0] Train:(L=1.506980, Acc=0.449868), Valid:(L=1.119623, Acc=0.598596), T: 00:00:48\n", |
|
|
|
"[ 1] Train:(L=1.022635, Acc=0.641504), Valid:(L=0.942414, Acc=0.669600), T: 00:00:47\n", |
|
|
|
"[ 2] Train:(L=0.806174, Acc=0.717551), Valid:(L=0.921687, Acc=0.682061), T: 00:00:47\n", |
|
|
|
"[ 3] Train:(L=0.638939, Acc=0.775555), Valid:(L=0.802450, Acc=0.729727), T: 00:00:47\n", |
|
|
|
"[ 4] Train:(L=0.497571, Acc=0.826606), Valid:(L=0.658700, Acc=0.775316), T: 00:00:47\n", |
|
|
|
"[ 5] Train:(L=0.364864, Acc=0.872442), Valid:(L=0.717290, Acc=0.768888), T: 00:00:47\n", |
|
|
|
"[ 6] Train:(L=0.263076, Acc=0.907888), Valid:(L=0.832575, Acc=0.750000), T: 00:00:47\n", |
|
|
|
"[ 7] Train:(L=0.181254, Acc=0.935782), Valid:(L=0.818366, Acc=0.764933), T: 00:00:47\n", |
|
|
|
"[ 8] Train:(L=0.124111, Acc=0.957820), Valid:(L=0.883527, Acc=0.778184), T: 00:00:47\n", |
|
|
|
"[ 9] Train:(L=0.108587, Acc=0.961657), Valid:(L=0.899127, Acc=0.780756), T: 00:00:47\n", |
|
|
|
"[10] Train:(L=0.091386, Acc=0.968670), Valid:(L=0.975022, Acc=0.781448), T: 00:00:47\n", |
|
|
|
"[11] Train:(L=0.079259, Acc=0.972287), Valid:(L=1.061239, Acc=0.770075), T: 00:00:47\n", |
|
|
|
"[12] Train:(L=0.067858, Acc=0.976123), Valid:(L=1.025909, Acc=0.782140), T: 00:00:47\n", |
|
|
|
"[13] Train:(L=0.064745, Acc=0.977701), Valid:(L=0.987410, Acc=0.789062), T: 00:00:47\n", |
|
|
|
"[14] Train:(L=0.056921, Acc=0.979779), Valid:(L=1.165746, Acc=0.773438), T: 00:00:47\n", |
|
|
|
"[15] Train:(L=0.058128, Acc=0.980039), Valid:(L=1.057119, Acc=0.782437), T: 00:00:47\n", |
|
|
|
"[16] Train:(L=0.050794, Acc=0.982257), Valid:(L=1.098127, Acc=0.779074), T: 00:00:47\n", |
|
|
|
"[17] Train:(L=0.046720, Acc=0.984415), Valid:(L=1.066124, Acc=0.787184), T: 00:00:47\n", |
|
|
|
"[18] Train:(L=0.044737, Acc=0.984375), Valid:(L=1.053032, Acc=0.792029), T: 00:00:47\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"train(net, train_data, test_data, 20, optimizer, criterion)" |
|
|
|
"res = train(net, train_data, test_data, 20, optimizer, criterion)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import matplotlib.pyplot as plt\n", |
|
|
|
"%matplotlib inline\n", |
|
|
|
"\n", |
|
|
|
"plt.plot(res[0], label='train')\n", |
|
|
|
"plt.plot(res[2], label='valid')\n", |
|
|
|
"plt.xlabel('epoch')\n", |
|
|
|
"plt.ylabel('Loss')\n", |
|
|
|
"plt.legend(loc='best')\n", |
|
|
|
"plt.savefig('fig-res-resnet-train-validate-loss.pdf')\n", |
|
|
|
"plt.show()\n", |
|
|
|
"\n", |
|
|
|
"plt.plot(res[1], label='train')\n", |
|
|
|
"plt.plot(res[3], label='valid')\n", |
|
|
|
"plt.xlabel('epoch')\n", |
|
|
|
"plt.ylabel('Acc')\n", |
|
|
|
"plt.legend(loc='best')\n", |
|
|
|
"plt.savefig('fig-res-resnet-train-validate-acc.pdf')\n", |
|
|
|
"plt.show()\n", |
|
|
|
"\n", |
|
|
|
"# save raw data\n", |
|
|
|
"import numpy\n", |
|
|
|
"numpy.save('fig-res-resnet_data.npy', res)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
@@ -418,7 +441,7 @@ |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.5.4" |
|
|
|
"version": "3.8.12" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|