
Improve training of LeNet5/AlexNet/VGG/GoogLeNet/ResNet

pull/10/MERGE
bushuhui 3 years ago
parent commit 4dd0a5df4c
6 changed files with 554 additions and 241 deletions
  1. +110  -13   7_deep_learning/1_CNN/02-LeNet5.ipynb
  2.  +95  -10   7_deep_learning/1_CNN/03-AlexNet.ipynb
  3. +153  -97   7_deep_learning/1_CNN/04-vgg.ipynb
  4. +109  -59   7_deep_learning/1_CNN/05-googlenet.ipynb
  5.  +78  -55   7_deep_learning/1_CNN/06-resnet.ipynb
  6.   +9   -7   7_deep_learning/1_CNN/utils.py

+ 110  - 13   7_deep_learning/1_CNN/02-LeNet5.ipynb   (file diff suppressed because it is too large)


+ 95  - 10   7_deep_learning/1_CNN/03-AlexNet.ipynb   (file diff suppressed because it is too large)


+ 153  - 97   7_deep_learning/1_CNN/04-vgg.ipynb   (file diff suppressed because it is too large)


+ 109  - 59   7_deep_learning/1_CNN/05-googlenet.ipynb   (file diff suppressed because it is too large)


+ 78  - 55   7_deep_learning/1_CNN/06-resnet.ipynb

@@ -42,36 +42,32 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T12:56:06.772059Z",
"start_time": "2017-12-22T12:56:06.766027Z"
},
"collapsed": true
}
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torch.autograd import Variable\n",
"from torchvision.datasets import CIFAR10"
"from torchvision.datasets import CIFAR10\n",
"from torchvision import transforms as tfs"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T12:47:49.222432Z",
"start_time": "2017-12-22T12:47:49.217940Z"
},
"collapsed": true
}
},
"outputs": [],
"source": [
@@ -82,19 +78,18 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:14:02.429145Z",
"start_time": "2017-12-22T13:14:02.383322Z"
},
"collapsed": true
}
},
"outputs": [],
"source": [
"class Residual_Block(nn.Module):\n",
" def __init__(self, in_channel, out_channel, same_shape=True):\n",
" super(residual_block, self).__init__()\n",
" super(Residual_Block, self).__init__()\n",
" self.same_shape = same_shape\n",
" stride=1 if self.same_shape else 2\n",
" \n",
@@ -127,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:14:05.793185Z",
@@ -155,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:14:11.929120Z",
@@ -201,13 +196,12 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:27:46.099404Z",
"start_time": "2017-12-22T13:27:45.986235Z"
},
"collapsed": true
}
},
"outputs": [],
"source": [
@@ -272,7 +266,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:28:00.597030Z",
@@ -302,39 +296,39 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:29:01.484172Z",
"start_time": "2017-12-22T13:29:00.095952Z"
},
"collapsed": true
}
},
"outputs": [],
"source": [
"from utils import train\n",
"\n",
"def data_tf(x):\n",
" x = x.resize((96, 96), 2) # 将图片放大到 96 x 96\n",
" x = np.array(x, dtype='float32') / 255\n",
" x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到\n",
" x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式\n",
" x = torch.from_numpy(x)\n",
" im_aug = tfs.Compose([\n",
" tfs.Resize(96),\n",
" tfs.ToTensor(),\n",
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
" ])\n",
" x = im_aug(x)\n",
" return x\n",
" \n",
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n",
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n",
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n",
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n",
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
"\n",
"net = ResNet(3, 10)\n",
"optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
"criterion = nn.CrossEntropyLoss()"
]
},
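Two changes land in this cell: preprocessing moves from the hand-rolled numpy pipeline to an equivalent torchvision transforms chain, and the optimizer switches from SGD(lr=0.01) to Adam(lr=1e-3). A standalone sanity check of the new pipeline, with a random image standing in for a CIFAR10 sample:

    from PIL import Image
    import numpy as np
    from torchvision import transforms as tfs

    im_aug = tfs.Compose([
        tfs.Resize(96),                                   # upsample the 32x32 CIFAR10 image to 96x96
        tfs.ToTensor(),                                   # HWC uint8 in [0, 255] -> CHW float in [0, 1]
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map [0, 1] -> [-1, 1] per channel
    ])

    im = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype='uint8'))
    x = im_aug(im)
    print(x.shape)  # torch.Size([3, 96, 96]), values in [-1, 1]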
{
"cell_type": "code",
"execution_count": 33,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T13:45:00.783186Z",
@@ -346,31 +340,60 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0. Train Loss: 1.437317, Train Acc: 0.476662, Valid Loss: 1.928288, Valid Acc: 0.384691, Time 00:00:44\n",
"Epoch 1. Train Loss: 0.992832, Train Acc: 0.648198, Valid Loss: 1.009847, Valid Acc: 0.642405, Time 00:00:48\n",
"Epoch 2. Train Loss: 0.767309, Train Acc: 0.732617, Valid Loss: 1.827319, Valid Acc: 0.430380, Time 00:00:47\n",
"Epoch 3. Train Loss: 0.606737, Train Acc: 0.788043, Valid Loss: 1.304808, Valid Acc: 0.585245, Time 00:00:46\n",
"Epoch 4. Train Loss: 0.484436, Train Acc: 0.834499, Valid Loss: 1.335749, Valid Acc: 0.617089, Time 00:00:47\n",
"Epoch 5. Train Loss: 0.374320, Train Acc: 0.872922, Valid Loss: 0.878519, Valid Acc: 0.724288, Time 00:00:47\n",
"Epoch 6. Train Loss: 0.280981, Train Acc: 0.904212, Valid Loss: 0.931616, Valid Acc: 0.716871, Time 00:00:48\n",
"Epoch 7. Train Loss: 0.210800, Train Acc: 0.929747, Valid Loss: 1.448870, Valid Acc: 0.638548, Time 00:00:48\n",
"Epoch 8. Train Loss: 0.147873, Train Acc: 0.951427, Valid Loss: 1.356992, Valid Acc: 0.657536, Time 00:00:47\n",
"Epoch 9. Train Loss: 0.112824, Train Acc: 0.963895, Valid Loss: 1.630560, Valid Acc: 0.627769, Time 00:00:47\n",
"Epoch 10. Train Loss: 0.082685, Train Acc: 0.973905, Valid Loss: 0.982882, Valid Acc: 0.744264, Time 00:00:44\n",
"Epoch 11. Train Loss: 0.065325, Train Acc: 0.979680, Valid Loss: 0.911631, Valid Acc: 0.767009, Time 00:00:47\n",
"Epoch 12. Train Loss: 0.041401, Train Acc: 0.987952, Valid Loss: 1.167992, Valid Acc: 0.729826, Time 00:00:48\n",
"Epoch 13. Train Loss: 0.037516, Train Acc: 0.989011, Valid Loss: 1.081807, Valid Acc: 0.746737, Time 00:00:47\n",
"Epoch 14. Train Loss: 0.030674, Train Acc: 0.991468, Valid Loss: 0.935292, Valid Acc: 0.774031, Time 00:00:45\n",
"Epoch 15. Train Loss: 0.021743, Train Acc: 0.994565, Valid Loss: 0.879348, Valid Acc: 0.790150, Time 00:00:47\n",
"Epoch 16. Train Loss: 0.014642, Train Acc: 0.996463, Valid Loss: 1.328587, Valid Acc: 0.724387, Time 00:00:47\n",
"Epoch 17. Train Loss: 0.011072, Train Acc: 0.997363, Valid Loss: 0.909065, Valid Acc: 0.792919, Time 00:00:47\n",
"Epoch 18. Train Loss: 0.006870, Train Acc: 0.998561, Valid Loss: 0.923746, Valid Acc: 0.794403, Time 00:00:46\n",
"Epoch 19. Train Loss: 0.004240, Train Acc: 0.999500, Valid Loss: 0.877908, Valid Acc: 0.802314, Time 00:00:46\n"
"[ 0] Train:(L=1.506980, Acc=0.449868), Valid:(L=1.119623, Acc=0.598596), T: 00:00:48\n",
"[ 1] Train:(L=1.022635, Acc=0.641504), Valid:(L=0.942414, Acc=0.669600), T: 00:00:47\n",
"[ 2] Train:(L=0.806174, Acc=0.717551), Valid:(L=0.921687, Acc=0.682061), T: 00:00:47\n",
"[ 3] Train:(L=0.638939, Acc=0.775555), Valid:(L=0.802450, Acc=0.729727), T: 00:00:47\n",
"[ 4] Train:(L=0.497571, Acc=0.826606), Valid:(L=0.658700, Acc=0.775316), T: 00:00:47\n",
"[ 5] Train:(L=0.364864, Acc=0.872442), Valid:(L=0.717290, Acc=0.768888), T: 00:00:47\n",
"[ 6] Train:(L=0.263076, Acc=0.907888), Valid:(L=0.832575, Acc=0.750000), T: 00:00:47\n",
"[ 7] Train:(L=0.181254, Acc=0.935782), Valid:(L=0.818366, Acc=0.764933), T: 00:00:47\n",
"[ 8] Train:(L=0.124111, Acc=0.957820), Valid:(L=0.883527, Acc=0.778184), T: 00:00:47\n",
"[ 9] Train:(L=0.108587, Acc=0.961657), Valid:(L=0.899127, Acc=0.780756), T: 00:00:47\n",
"[10] Train:(L=0.091386, Acc=0.968670), Valid:(L=0.975022, Acc=0.781448), T: 00:00:47\n",
"[11] Train:(L=0.079259, Acc=0.972287), Valid:(L=1.061239, Acc=0.770075), T: 00:00:47\n",
"[12] Train:(L=0.067858, Acc=0.976123), Valid:(L=1.025909, Acc=0.782140), T: 00:00:47\n",
"[13] Train:(L=0.064745, Acc=0.977701), Valid:(L=0.987410, Acc=0.789062), T: 00:00:47\n",
"[14] Train:(L=0.056921, Acc=0.979779), Valid:(L=1.165746, Acc=0.773438), T: 00:00:47\n",
"[15] Train:(L=0.058128, Acc=0.980039), Valid:(L=1.057119, Acc=0.782437), T: 00:00:47\n",
"[16] Train:(L=0.050794, Acc=0.982257), Valid:(L=1.098127, Acc=0.779074), T: 00:00:47\n",
"[17] Train:(L=0.046720, Acc=0.984415), Valid:(L=1.066124, Acc=0.787184), T: 00:00:47\n",
"[18] Train:(L=0.044737, Acc=0.984375), Valid:(L=1.053032, Acc=0.792029), T: 00:00:47\n"
]
}
],
"source": [
"train(net, train_data, test_data, 20, optimizer, criterion)"
"res = train(net, train_data, test_data, 20, optimizer, criterion)"
]
},
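train is now called for its return value. The exact return statement is not shown in the utils.py hunks below, but the per-epoch lists accumulated there and the indexing in the plotting cell that follows imply this layout (an inference, not shown verbatim in the diff):

    # inferred layout of utils.train's return value:
    #   res[0] train loss, res[1] train accuracy,
    #   res[2] valid loss, res[3] valid accuracy -- one entry per epoch
    train_loss, train_acc, valid_loss, valid_acc = res
    best = max(range(len(valid_acc)), key=valid_acc.__getitem__)
    print('best valid acc %.4f at epoch %d' % (valid_acc[best], best))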
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"plt.plot(res[0], label='train')\n",
"plt.plot(res[2], label='valid')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend(loc='best')\n",
"plt.savefig('fig-res-resnet-train-validate-loss.pdf')\n",
"plt.show()\n",
"\n",
"plt.plot(res[1], label='train')\n",
"plt.plot(res[3], label='valid')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('Acc')\n",
"plt.legend(loc='best')\n",
"plt.savefig('fig-res-resnet-train-validate-acc.pdf')\n",
"plt.show()\n",
"\n",
"# save raw data\n",
"import numpy\n",
"numpy.save('fig-res-resnet_data.npy', res)"
]
},
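Because the curves are also dumped to disk, they can be reloaded later to compare runs across architectures; saving four equal-length lists yields a (4, num_epochs) array (shape assumed from the layout above):

    import numpy as np

    res = np.load('fig-res-resnet_data.npy')  # assumed shape: (4, num_epochs)
    train_loss, train_acc, valid_loss, valid_acc = res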
{
@@ -418,7 +441,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
"version": "3.8.12"
}
},
"nbformat": 4,


+ 9
- 7
7_deep_learning/1_CNN/utils.py View File

@@ -47,10 +47,7 @@ def train(net, train_data, valid_data, num_epochs, optimizer, criterion, use_cuda
train_loss += loss.item()
train_acc += get_acc(output, label)

cur_time = datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)

if valid_data is not None:
valid_loss = 0
valid_acc = 0
@@ -67,7 +64,7 @@ def train(net, train_data, valid_data, num_epochs, optimizer, criterion, use_cuda
valid_loss += loss.item()
valid_acc += get_acc(output, label)
epoch_str = (
"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
"[%2d] Train:(L=%f, Acc=%f), Valid:(L=%f, Acc=%f), "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_loss / len(valid_data),
valid_acc / len(valid_data)))
@@ -75,13 +72,18 @@ def train(net, train_data, valid_data, num_epochs, optimizer, criterion, use_cud
l_valid_acc.append(valid_acc / len(valid_data))
l_valid_loss.append(valid_loss / len(valid_data))
else:
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
epoch_str = ("[%2d] Train:(L=%f, Acc=%f), " %
(epoch, train_loss / len(train_data),
train_acc / len(train_data)))

l_train_acc.append(train_acc / len(train_data))
l_train_loss.append(train_loss / len(train_data))
cur_time = datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "T: %02d:%02d:%02d" % (h, m, s)
prev_time = cur_time
print(epoch_str + time_str)
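The net effect in utils.py: the timestamp block moves from between the two passes to after validation, so the printed per-epoch time now covers training plus validation instead of training alone. A toy sketch of the resulting flow (sleeps and zeroed metrics stand in for the real passes):

    from datetime import datetime
    import time

    prev_time = datetime.now()
    for epoch in range(2):
        time.sleep(0.2)  # stand-in for the training pass
        time.sleep(0.1)  # stand-in for the validation pass
        epoch_str = "[%2d] Train:(L=%f, Acc=%f), Valid:(L=%f, Acc=%f), " % (epoch, 0, 0, 0, 0)
        cur_time = datetime.now()             # taken after BOTH passes
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "T: %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time                  # restart the clock for the next epoch
        print(epoch_str + time_str)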

