|
@@ -20,7 +20,7 @@ |
|
|
"cell_type": "markdown", |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"如果我们对新的损失函数 f 求导进行梯度下降,就有\n", |
|
|
|
|
|
|
|
|
"如果对新的损失函数 $f$ 求导进行梯度下降,就有\n", |
|
|
"\n", |
|
|
"\n", |
|
|
"$$\n", |
|
|
"$$\n", |
|
|
"\\frac{\\partial f}{\\partial p_j} = \\frac{\\partial loss}{\\partial p_j} + 2 \\lambda p_j\n", |
|
|
"\\frac{\\partial f}{\\partial p_j} = \\frac{\\partial loss}{\\partial p_j} + 2 \\lambda p_j\n", |
|
@@ -37,9 +37,9 @@ |
|
|
"cell_type": "markdown", |
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"可以看到 $p_j - \\eta \\frac{\\partial loss}{\\partial p_j}$ 和没加正则项要更新的部分一样,而后面的 $2\\eta \\lambda p_j$ 就是正则项的影响,可以看到加完正则项之后会对参数做更大程度的更新,这也被称为权重衰减(weight decay),在 pytorch 中正则项就是通过这种方式来加入的,比如想在随机梯度下降法中使用正则项,或者说权重衰减,`torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)` 就可以了,这个 `weight_decay` 系数就是上面公式中的 $\\lambda$,非常方便\n", |
|
|
|
|
|
|
|
|
"可以看到 $p_j - \\eta \\frac{\\partial loss}{\\partial p_j}$ 和没加正则项要更新的部分一样,而后面的 $2\\eta \\lambda p_j$ 就是正则项的影响,可以看到加完正则项之后会对参数做更大程度的更新,这也被称为权重衰减(weight decay)。在 PyTorch 中正则项就是通过这种方式来加入的,比如想在随机梯度下降法中使用正则项,或者说权重衰减,`torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)` 就可以了,这个 `weight_decay` 系数就是上面公式中的 $\\lambda$,非常方便\n", |
|
|
"\n", |
|
|
"\n", |
|
|
"注意正则项的系数的大小非常重要,如果太大,会极大的抑制参数的更新,导致欠拟合,如果太小,那么正则项这个部分基本没有贡献,所以选择一个合适的权重衰减系数非常重要,这个需要根据具体的情况去尝试,初步尝试可以使用 `1e-4` 或者 `1e-3` \n", |
|
|
|
|
|
|
|
|
"注意正则项的系数的大小非常重要,如果太大,会极大的抑制参数的更新,导致欠拟合;如果太小,那么正则项这个部分基本没有贡献。所以选择一个合适的权重衰减系数非常重要,这个需要根据具体的情况去尝试,初步尝试可以使用 `1e-4` 或者 `1e-3` \n", |
|
|
"\n", |
|
|
"\n", |
|
|
"下面我们在训练 cifar 10 中添加正则项" |
|
|
"下面我们在训练 cifar 10 中添加正则项" |
|
|
] |
|
|
] |
|
@@ -159,7 +159,7 @@ |
|
|
"name": "python", |
|
|
"name": "python", |
|
|
"nbconvert_exporter": "python", |
|
|
"nbconvert_exporter": "python", |
|
|
"pygments_lexer": "ipython3", |
|
|
"pygments_lexer": "ipython3", |
|
|
"version": "3.5.4" |
|
|
|
|
|
|
|
|
"version": "3.7.9" |
|
|
} |
|
|
} |
|
|
}, |
|
|
}, |
|
|
"nbformat": 4, |
|
|
"nbformat": 4, |
|
|