|
|
@@ -0,0 +1,167 @@ |
|
|
|
{ |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"# 正则化\n", |
|
|
|
"前面我们讲了数据增强和 dropout,而在实际使用中,现在的网络往往不使用 dropout,而是用另外一个技术,叫正则化。\n", |
|
|
|
"\n", |
|
|
|
"正则化是机器学习中提出来的一种方法,有 L1 和 L2 正则化,目前使用较多的是 L2 正则化,引入正则化相当于在 loss 函数上面加上一项,比如\n", |
|
|
|
"\n", |
|
|
|
"$$\n", |
|
|
|
"f = loss + \\lambda \\sum_{p \\in params} ||p||_2^2\n", |
|
|
|
"$$\n", |
|
|
|
"\n", |
|
|
|
"就是在 loss 的基础上加上了参数的二范数作为一个正则化,我们在训练网络的时候,不仅要最小化 loss 函数,同时还要最小化参数的二范数,也就是说我们会对参数做一些限制,不让它变得太大。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"如果我们对新的损失函数 f 求导进行梯度下降,就有\n", |
|
|
|
"\n", |
|
|
|
"$$\n", |
|
|
|
"\\frac{\\partial f}{\\partial p_j} = \\frac{\\partial loss}{\\partial p_j} + 2 \\lambda p_j\n", |
|
|
|
"$$\n", |
|
|
|
"\n", |
|
|
|
"那么在更新参数的时候就有\n", |
|
|
|
"\n", |
|
|
|
"$$\n", |
|
|
|
"p_j \\rightarrow p_j - \\eta (\\frac{\\partial loss}{\\partial p_j} + 2 \\lambda p_j) = p_j - \\eta \\frac{\\partial loss}{\\partial p_j} - 2 \\eta \\lambda p_j \n", |
|
|
|
"$$\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"可以看到 $p_j - \\eta \\frac{\\partial loss}{\\partial p_j}$ 和没加正则项要更新的部分一样,而后面的 $2\\eta \\lambda p_j$ 就是正则项的影响,可以看到加完正则项之后会对参数做更大程度的更新,这也被称为权重衰减(weight decay),在 pytorch 中正则项就是通过这种方式来加入的,比如想在随机梯度下降法中使用正则项,或者说权重衰减,`torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)` 就可以了,这个 `weight_decay` 系数就是上面公式中的 $\\lambda$,非常方便\n", |
|
|
|
"\n", |
|
|
|
"注意正则项的系数的大小非常重要,如果太大,会极大的抑制参数的更新,导致欠拟合,如果太小,那么正则项这个部分基本没有贡献,所以选择一个合适的权重衰减系数非常重要,这个需要根据具体的情况去尝试,初步尝试可以使用 `1e-4` 或者 `1e-3` \n", |
|
|
|
"\n", |
|
|
|
"下面我们在训练 cifar 10 中添加正则项" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 1, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-24T08:02:11.903459Z", |
|
|
|
"start_time": "2017-12-24T08:02:11.383170Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import sys\n", |
|
|
|
"sys.path.append('..')\n", |
|
|
|
"\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch import nn\n", |
|
|
|
"import torch.nn.functional as F\n", |
|
|
|
"from torch.autograd import Variable\n", |
|
|
|
"from torchvision.datasets import CIFAR10\n", |
|
|
|
"from utils import train, resnet\n", |
|
|
|
"from torchvision import transforms as tfs" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 2, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-24T08:02:13.120502Z", |
|
|
|
"start_time": "2017-12-24T08:02:11.905617Z" |
|
|
|
}, |
|
|
|
"collapsed": true |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"def data_tf(x):\n", |
|
|
|
" im_aug = tfs.Compose([\n", |
|
|
|
" tfs.Resize(96),\n", |
|
|
|
" tfs.ToTensor(),\n", |
|
|
|
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", |
|
|
|
" ])\n", |
|
|
|
" x = im_aug(x)\n", |
|
|
|
" return x\n", |
|
|
|
"\n", |
|
|
|
"train_set = CIFAR10('./data', train=True, transform=data_tf)\n", |
|
|
|
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)\n", |
|
|
|
"test_set = CIFAR10('./data', train=False, transform=data_tf)\n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)\n", |
|
|
|
"\n", |
|
|
|
"net = resnet(3, 10)\n", |
|
|
|
"optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4) # 增加正则项\n", |
|
|
|
"criterion = nn.CrossEntropyLoss()" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 3, |
|
|
|
"metadata": { |
|
|
|
"ExecuteTime": { |
|
|
|
"end_time": "2017-12-24T08:11:36.106177Z", |
|
|
|
"start_time": "2017-12-24T08:02:13.122785Z" |
|
|
|
} |
|
|
|
}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 0. Train Loss: 1.429834, Train Acc: 0.476982, Valid Loss: 1.261334, Valid Acc: 0.546776, Time 00:00:26\n", |
|
|
|
"Epoch 1. Train Loss: 0.994539, Train Acc: 0.645400, Valid Loss: 1.310620, Valid Acc: 0.554688, Time 00:00:27\n", |
|
|
|
"Epoch 2. Train Loss: 0.788570, Train Acc: 0.723585, Valid Loss: 1.256101, Valid Acc: 0.577433, Time 00:00:28\n", |
|
|
|
"Epoch 3. Train Loss: 0.629832, Train Acc: 0.780411, Valid Loss: 1.222015, Valid Acc: 0.609474, Time 00:00:27\n", |
|
|
|
"Epoch 4. Train Loss: 0.500406, Train Acc: 0.825288, Valid Loss: 0.831702, Valid Acc: 0.720332, Time 00:00:27\n", |
|
|
|
"Epoch 5. Train Loss: 0.388376, Train Acc: 0.868646, Valid Loss: 0.829582, Valid Acc: 0.726760, Time 00:00:27\n", |
|
|
|
"Epoch 6. Train Loss: 0.291237, Train Acc: 0.902094, Valid Loss: 1.499777, Valid Acc: 0.623714, Time 00:00:28\n", |
|
|
|
"Epoch 7. Train Loss: 0.222401, Train Acc: 0.925072, Valid Loss: 1.832660, Valid Acc: 0.558643, Time 00:00:28\n", |
|
|
|
"Epoch 8. Train Loss: 0.157753, Train Acc: 0.947990, Valid Loss: 1.255313, Valid Acc: 0.668117, Time 00:00:28\n", |
|
|
|
"Epoch 9. Train Loss: 0.111407, Train Acc: 0.963595, Valid Loss: 1.004693, Valid Acc: 0.724782, Time 00:00:27\n", |
|
|
|
"Epoch 10. Train Loss: 0.084960, Train Acc: 0.972926, Valid Loss: 0.867961, Valid Acc: 0.775119, Time 00:00:27\n", |
|
|
|
"Epoch 11. Train Loss: 0.066854, Train Acc: 0.979280, Valid Loss: 1.011263, Valid Acc: 0.749604, Time 00:00:28\n", |
|
|
|
"Epoch 12. Train Loss: 0.048280, Train Acc: 0.985534, Valid Loss: 2.438345, Valid Acc: 0.576938, Time 00:00:27\n", |
|
|
|
"Epoch 13. Train Loss: 0.046176, Train Acc: 0.985614, Valid Loss: 1.008425, Valid Acc: 0.756527, Time 00:00:27\n", |
|
|
|
"Epoch 14. Train Loss: 0.039515, Train Acc: 0.988411, Valid Loss: 0.945017, Valid Acc: 0.766317, Time 00:00:27\n", |
|
|
|
"Epoch 15. Train Loss: 0.025882, Train Acc: 0.992667, Valid Loss: 0.918691, Valid Acc: 0.784217, Time 00:00:27\n", |
|
|
|
"Epoch 16. Train Loss: 0.018592, Train Acc: 0.994985, Valid Loss: 1.507427, Valid Acc: 0.680281, Time 00:00:27\n", |
|
|
|
"Epoch 17. Train Loss: 0.021062, Train Acc: 0.994246, Valid Loss: 2.976452, Valid Acc: 0.558940, Time 00:00:27\n", |
|
|
|
"Epoch 18. Train Loss: 0.021458, Train Acc: 0.993926, Valid Loss: 0.927871, Valid Acc: 0.785898, Time 00:00:27\n", |
|
|
|
"Epoch 19. Train Loss: 0.015656, Train Acc: 0.995824, Valid Loss: 0.962502, Valid Acc: 0.782832, Time 00:00:27\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"from utils import train\n", |
|
|
|
"train(net, train_data, test_data, 20, optimizer, criterion)" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"kernelspec": { |
|
|
|
"display_name": "Python 3", |
|
|
|
"language": "python", |
|
|
|
"name": "python3" |
|
|
|
}, |
|
|
|
"language_info": { |
|
|
|
"codemirror_mode": { |
|
|
|
"name": "ipython", |
|
|
|
"version": 3 |
|
|
|
}, |
|
|
|
"file_extension": ".py", |
|
|
|
"mimetype": "text/x-python", |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.5.2" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 2 |
|
|
|
} |