|
|
@@ -20,23 +20,13 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 1, |
|
|
|
"metadata": { |
|
|
|
"collapsed": true |
|
|
|
}, |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import sys\n", |
|
|
|
"sys.path.append('..')\n", |
|
|
|
"\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch import nn\n", |
|
|
|
"from torch.autograd import Variable\n", |
|
|
|
"import torch.nn.functional as F\n", |
|
|
|
"from torchvision.datasets import CIFAR10\n", |
|
|
|
"from torchvision import transforms as tfs\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"class LeNet5(nn.Module):\n", |
|
|
|
" def __init__(self):\n", |
|
|
@@ -64,23 +54,9 @@ |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 4, |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stdout", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"LeNet5(\n", |
|
|
|
" (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n", |
|
|
|
" (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", |
|
|
|
" (fc1): Linear(in_features=400, out_features=120, bias=True)\n", |
|
|
|
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n", |
|
|
|
" (fc3): Linear(in_features=84, out_features=10, bias=True)\n", |
|
|
|
")\n" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"net = LeNet5()\n", |
|
|
|
"print(net)" |
|
|
@@ -89,54 +65,103 @@ |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"collapsed": true |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"input = torch.randn(1, 1, 32, 32)\n", |
|
|
|
"out = net(input)\n", |
|
|
|
"print(out)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import numpy as np\n", |
|
|
|
"from torchvision.datasets import mnist\n", |
|
|
|
"from torch.utils.data import DataLoader\n", |
|
|
|
"from torchvision.datasets import mnist \n", |
|
|
|
"from torchvision import transforms as tfs\n", |
|
|
|
"from utils import train\n", |
|
|
|
"\n", |
|
|
|
"# 使用数据增强\n", |
|
|
|
"def data_tf(x):\n", |
|
|
|
" im_aug = tfs.Compose([\n", |
|
|
|
" tfs.Resize(32),\n", |
|
|
|
" tfs.ToTensor(),\n", |
|
|
|
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", |
|
|
|
" tfs.ToTensor() #,\n", |
|
|
|
" #tfs.Normalize([0.5], [0.5])\n", |
|
|
|
" ])\n", |
|
|
|
" x = im_aug(x)\n", |
|
|
|
" return x\n", |
|
|
|
" \n", |
|
|
|
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n", |
|
|
|
"train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True) \n", |
|
|
|
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", |
|
|
|
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", |
|
|
|
"test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True) \n", |
|
|
|
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", |
|
|
|
"\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"# 显示其中一个数据\n", |
|
|
|
"import matplotlib.pyplot as plt\n", |
|
|
|
"plt.imshow(train_set.data[0], cmap='gray')\n", |
|
|
|
"plt.title('%i' % train_set.targets[0])\n", |
|
|
|
"plt.colorbar()\n", |
|
|
|
"plt.show()" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import matplotlib.pyplot as plt\n", |
|
|
|
"\n", |
|
|
|
"net = LeNet5()\n", |
|
|
|
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-1)\n", |
|
|
|
"criterion = nn.CrossEntropyLoss()" |
|
|
|
"# 显示转化后的图像\n", |
|
|
|
"for im, label in train_data:\n", |
|
|
|
" print(im.shape)\n", |
|
|
|
" print(label.shape)\n", |
|
|
|
" \n", |
|
|
|
" img = im[0,0,:,:]\n", |
|
|
|
" lab = label[0]\n", |
|
|
|
" plt.imshow(img, cmap='gray')\n", |
|
|
|
" plt.title('%i' % lab)\n", |
|
|
|
" plt.colorbar()\n", |
|
|
|
" plt.show()\n", |
|
|
|
"\n", |
|
|
|
" print(im[0,0,:,:])\n", |
|
|
|
" break" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"collapsed": true |
|
|
|
"scrolled": false |
|
|
|
}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"net = LeNet5()\n", |
|
|
|
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", |
|
|
|
"criterion = nn.CrossEntropyLoss()\n", |
|
|
|
"\n", |
|
|
|
"res = train(net, train_data, test_data, 20, \n", |
|
|
|
" optimizer, criterion,\n", |
|
|
|
" use_cuda=False)" |
|
|
|
" use_cuda=True)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"collapsed": true |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import matplotlib.pyplot as plt\n", |
|
|
@@ -145,6 +170,7 @@ |
|
|
|
"plt.plot(res[0], label='train')\n", |
|
|
|
"plt.plot(res[2], label='valid')\n", |
|
|
|
"plt.xlabel('epoch')\n", |
|
|
|
"plt.ylabel('Loss')\n", |
|
|
|
"plt.legend(loc='best')\n", |
|
|
|
"plt.savefig('fig-res-lenet5-train-validate-loss.pdf')\n", |
|
|
|
"plt.show()\n", |
|
|
@@ -152,6 +178,7 @@ |
|
|
|
"plt.plot(res[1], label='train')\n", |
|
|
|
"plt.plot(res[3], label='valid')\n", |
|
|
|
"plt.xlabel('epoch')\n", |
|
|
|
"plt.ylabel('Acc')\n", |
|
|
|
"plt.legend(loc='best')\n", |
|
|
|
"plt.savefig('fig-res-lenet5-train-validate-acc.pdf')\n", |
|
|
|
"plt.show()" |
|
|
@@ -174,7 +201,7 @@ |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.5.4" |
|
|
|
"version": "3.7.9" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|