|
|
@@ -85,12 +85,9 @@ |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import numpy as np\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch import nn\n", |
|
|
|
"from torch.autograd import Variable\n", |
|
|
|
"from torchvision.datasets import CIFAR10\n", |
|
|
|
"from torchvision import transforms as tfs\n", |
|
|
|
"\n", |
|
|
|
"def VGG_Block(num_convs, in_channels, out_channels):\n", |
|
|
|
" # 定义第一层\n", |
|
|
@@ -364,19 +361,12 @@ |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"from torchvision.datasets import CIFAR10\n", |
|
|
|
"from torchvision import transforms as tfs\n", |
|
|
|
"from utils import train\n", |
|
|
|
"\n", |
|
|
|
"# 使用数据增强\n", |
|
|
|
"def train_tf(x):\n", |
|
|
|
" im_aug = tfs.Compose([\n", |
|
|
|
" #tfs.Resize(224),\n", |
|
|
|
" tfs.ToTensor(),\n", |
|
|
|
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", |
|
|
|
" ])\n", |
|
|
|
" x = im_aug(x)\n", |
|
|
|
" return x\n", |
|
|
|
"\n", |
|
|
|
"def test_tf(x):\n", |
|
|
|
"# 使用数据变换\n", |
|
|
|
"def data_tf(x):\n", |
|
|
|
" im_aug = tfs.Compose([\n", |
|
|
|
" #tfs.Resize(224),\n", |
|
|
|
" tfs.ToTensor(),\n", |
|
|
@@ -385,10 +375,10 @@ |
|
|
|
" x = im_aug(x)\n", |
|
|
|
" return x\n", |
|
|
|
" \n", |
|
|
|
"train_set = CIFAR10('../../data', train=True, transform=train_tf)\n", |
|
|
|
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n", |
|
|
|
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", |
|
|
|
"test_set = CIFAR10('../../data', train=False, transform=test_tf)\n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", |
|
|
|
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", |
|
|
|
"\n", |
|
|
|
"net = VGG_Net()\n", |
|
|
|
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", |
|
|
@@ -409,26 +399,26 @@ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Epoch 0. Train Loss: 1.746658, Train Acc: 0.313979, Valid Loss: 1.441544, Valid Acc: 0.459454, Time 00:00:24\n", |
|
|
|
"Epoch 1. Train Loss: 1.293454, Train Acc: 0.532868, Valid Loss: 1.064248, Valid Acc: 0.630934, Time 00:00:27\n", |
|
|
|
"Epoch 2. Train Loss: 0.976450, Train Acc: 0.663643, Valid Loss: 0.865004, Valid Acc: 0.693730, Time 00:00:27\n", |
|
|
|
"Epoch 3. Train Loss: 0.797400, Train Acc: 0.734835, Valid Loss: 0.752283, Valid Acc: 0.745847, Time 00:00:27\n", |
|
|
|
"Epoch 4. Train Loss: 0.669900, Train Acc: 0.782389, Valid Loss: 0.654101, Valid Acc: 0.787085, Time 00:00:27\n", |
|
|
|
"Epoch 5. Train Loss: 0.573636, Train Acc: 0.815837, Valid Loss: 0.637741, Valid Acc: 0.797765, Time 00:00:27\n", |
|
|
|
"Epoch 6. Train Loss: 0.499975, Train Acc: 0.839174, Valid Loss: 0.571682, Valid Acc: 0.809237, Time 00:00:27\n", |
|
|
|
"Epoch 7. Train Loss: 0.428868, Train Acc: 0.862372, Valid Loss: 0.505937, Valid Acc: 0.831982, Time 00:00:27\n", |
|
|
|
"Epoch 8. Train Loss: 0.373129, Train Acc: 0.882133, Valid Loss: 0.559180, Valid Acc: 0.821203, Time 00:00:27\n", |
|
|
|
"Epoch 9. Train Loss: 0.323578, Train Acc: 0.895920, Valid Loss: 0.582161, Valid Acc: 0.827927, Time 00:00:27\n", |
|
|
|
"Epoch 10. Train Loss: 0.274420, Train Acc: 0.913163, Valid Loss: 0.623256, Valid Acc: 0.813390, Time 00:00:27\n", |
|
|
|
"Epoch 11. Train Loss: 0.239558, Train Acc: 0.924692, Valid Loss: 0.561333, Valid Acc: 0.840882, Time 00:00:27\n", |
|
|
|
"Epoch 12. Train Loss: 0.207196, Train Acc: 0.935042, Valid Loss: 0.537481, Valid Acc: 0.843256, Time 00:00:27\n", |
|
|
|
"Epoch 13. Train Loss: 0.180536, Train Acc: 0.943454, Valid Loss: 0.697694, Valid Acc: 0.818631, Time 00:00:27\n", |
|
|
|
"Epoch 14. Train Loss: 0.157166, Train Acc: 0.950607, Valid Loss: 0.542898, Valid Acc: 0.857298, Time 00:00:27\n", |
|
|
|
"Epoch 15. Train Loss: 0.138539, Train Acc: 0.957201, Valid Loss: 0.602181, Valid Acc: 0.847112, Time 00:00:27\n", |
|
|
|
"Epoch 16. Train Loss: 0.130947, Train Acc: 0.960418, Valid Loss: 0.607590, Valid Acc: 0.852453, Time 00:00:27\n", |
|
|
|
"Epoch 17. Train Loss: 0.109348, Train Acc: 0.966972, Valid Loss: 0.636679, Valid Acc: 0.848497, Time 00:00:27\n", |
|
|
|
"Epoch 18. Train Loss: 0.099000, Train Acc: 0.970009, Valid Loss: 0.640463, Valid Acc: 0.849684, Time 00:00:27\n", |
|
|
|
"Epoch 19. Train Loss: 0.083908, Train Acc: 0.974604, Valid Loss: 0.630587, Valid Acc: 0.859771, Time 00:00:27\n" |
|
|
|
"[ 0] Train:(L=1.689224, Acc=0.340493), Valid:(L=1.514729, Acc=0.449664), T: 00:00:27\n", |
|
|
|
"[ 1] Train:(L=1.211734, Acc=0.572111), Valid:(L=1.043950, Acc=0.638944), T: 00:00:27\n", |
|
|
|
"[ 2] Train:(L=0.939749, Acc=0.680647), Valid:(L=0.795742, Acc=0.731408), T: 00:00:27\n", |
|
|
|
"[ 3] Train:(L=0.776414, Acc=0.742987), Valid:(L=0.773068, Acc=0.741792), T: 00:00:27\n", |
|
|
|
"[ 4] Train:(L=0.655303, Acc=0.784607), Valid:(L=0.697191, Acc=0.759296), T: 00:00:27\n", |
|
|
|
"[ 5] Train:(L=0.565006, Acc=0.816956), Valid:(L=0.634936, Acc=0.791436), T: 00:00:27\n", |
|
|
|
"[ 6] Train:(L=0.487787, Acc=0.842172), Valid:(L=0.643098, Acc=0.788766), T: 00:00:27\n", |
|
|
|
"[ 7] Train:(L=0.422939, Acc=0.866129), Valid:(L=0.539120, Acc=0.828323), T: 00:00:27\n", |
|
|
|
"[ 8] Train:(L=0.365580, Acc=0.882413), Valid:(L=0.598219, Acc=0.808248), T: 00:00:27\n", |
|
|
|
"[ 9] Train:(L=0.316299, Acc=0.899097), Valid:(L=0.601980, Acc=0.821005), T: 00:00:27\n", |
|
|
|
"[10] Train:(L=0.271955, Acc=0.914043), Valid:(L=0.664353, Acc=0.803402), T: 00:00:27\n", |
|
|
|
"[11] Train:(L=0.240455, Acc=0.923633), Valid:(L=0.555360, Acc=0.831191), T: 00:00:27\n", |
|
|
|
"[12] Train:(L=0.199894, Acc=0.937100), Valid:(L=0.514384, Acc=0.857793), T: 00:00:27\n", |
|
|
|
"[13] Train:(L=0.175543, Acc=0.944973), Valid:(L=0.641336, Acc=0.842860), T: 00:00:27\n", |
|
|
|
"[14] Train:(L=0.149222, Acc=0.953285), Valid:(L=0.600546, Acc=0.847706), T: 00:00:27\n", |
|
|
|
"[15] Train:(L=0.137142, Acc=0.957221), Valid:(L=0.597016, Acc=0.851958), T: 00:00:27\n", |
|
|
|
"[16] Train:(L=0.121605, Acc=0.962456), Valid:(L=0.638970, Acc=0.850870), T: 00:00:27\n", |
|
|
|
"[17] Train:(L=0.101306, Acc=0.970149), Valid:(L=0.602158, Acc=0.851760), T: 00:00:27\n", |
|
|
|
"[18] Train:(L=0.103552, Acc=0.970009), Valid:(L=0.619986, Acc=0.844640), T: 00:00:27\n", |
|
|
|
"[19] Train:(L=0.092559, Acc=0.972047), Valid:(L=0.692679, Acc=0.841772), T: 00:00:27\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|