
Add result of DenseNet

pull/10/MERGE
bushuhui 3 years ago
commit 9bc3e63049
3 changed files with 136 additions and 102 deletions
  1. 7_deep_learning/1_CNN/02-LeNet5.ipynb (+2, -7)
  2. 7_deep_learning/1_CNN/04-vgg.ipynb (+27, -37)
  3. 7_deep_learning/1_CNN/07-densenet.ipynb (+107, -58)

7_deep_learning/1_CNN/02-LeNet5.ipynb (+2, -7)

@@ -102,10 +102,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n",
"from torchvision.datasets import mnist\n", "from torchvision.datasets import mnist\n",
"from torch.utils.data import DataLoader\n", "from torch.utils.data import DataLoader\n",
"from torchvision.datasets import mnist \n",
"from torchvision import transforms as tfs\n", "from torchvision import transforms as tfs\n",
"from utils import train\n", "from utils import train\n",
"\n", "\n",
@@ -122,8 +120,7 @@
"train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True) \n", "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True) \n",
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
"test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True) \n", "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True) \n",
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
"\n"
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)"
] ]
}, },
{ {
@@ -250,9 +247,7 @@
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
"criterion = nn.CrossEntropyLoss()\n", "criterion = nn.CrossEntropyLoss()\n",
"\n", "\n",
"res = train(net, train_data, test_data, 20, \n",
" optimizer, criterion,\n",
" use_cuda=True)"
"res = train(net, train_data, test_data, 20, optimizer, criterion)"
] ]
}, },
{ {
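The last hunk drops the explicit use_cuda=True argument, so utils.train is now expected to handle device placement on its own. The helper's body is not part of this commit; what follows is only a minimal sketch of a training loop consistent with the new call train(net, train_data, test_data, 20, optimizer, criterion) - the device handling, metric computation, and log format are all assumed, not taken from utils.py.

# Minimal sketch only: the real utils.train is not shown in this diff,
# so every detail below (device handling, logging format) is an assumption.
import time
import torch

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    for epoch in range(num_epochs):
        tic = time.time()
        # training pass
        net.train()
        train_loss, train_acc, n = 0.0, 0.0, 0
        for im, label in train_data:
            im, label = im.to(device), label.to(device)
            out = net(im)
            loss = criterion(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * label.size(0)
            train_acc += (out.argmax(dim=1) == label).sum().item()
            n += label.size(0)
        # validation pass
        net.eval()
        valid_loss, valid_acc, m = 0.0, 0.0, 0
        with torch.no_grad():
            for im, label in valid_data:
                im, label = im.to(device), label.to(device)
                out = net(im)
                valid_loss += criterion(out, label).item() * label.size(0)
                valid_acc += (out.argmax(dim=1) == label).sum().item()
                m += label.size(0)
        elapsed = time.strftime('%H:%M:%S', time.gmtime(time.time() - tic))
        print('[%2d] Train:(L=%.6f, Acc=%.6f), Valid:(L=%.6f, Acc=%.6f), T: %s'
              % (epoch, train_loss / n, train_acc / n,
                 valid_loss / m, valid_acc / m, elapsed))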


7_deep_learning/1_CNN/04-vgg.ipynb (+27, -37)

@@ -85,12 +85,9 @@
 },
 "outputs": [],
 "source": [
-"import numpy as np\n",
 "import torch\n",
 "from torch import nn\n",
 "from torch.autograd import Variable\n",
-"from torchvision.datasets import CIFAR10\n",
-"from torchvision import transforms as tfs\n",
 "\n",
 "def VGG_Block(num_convs, in_channels, out_channels):\n",
 "    # define the first layer\n",
@@ -364,19 +361,12 @@
 },
 "outputs": [],
 "source": [
+"from torchvision.datasets import CIFAR10\n",
+"from torchvision import transforms as tfs\n",
 "from utils import train\n",
 "\n",
-"# use data augmentation\n",
-"def train_tf(x):\n",
-"    im_aug = tfs.Compose([\n",
-"        #tfs.Resize(224),\n",
-"        tfs.ToTensor(),\n",
-"        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
-"    ])\n",
-"    x = im_aug(x)\n",
-"    return x\n",
-"\n",
-"def test_tf(x):\n",
+"# use data transform\n",
+"def data_tf(x):\n",
 "    im_aug = tfs.Compose([\n",
 "        #tfs.Resize(224),\n",
 "        tfs.ToTensor(),\n",
@@ -385,10 +375,10 @@
" x = im_aug(x)\n", " x = im_aug(x)\n",
" return x\n", " return x\n",
" \n", " \n",
"train_set = CIFAR10('../../data', train=True, transform=train_tf)\n",
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n",
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
"test_set = CIFAR10('../../data', train=False, transform=test_tf)\n",
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n",
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
"\n", "\n",
"net = VGG_Net()\n", "net = VGG_Net()\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
@@ -409,26 +399,26 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 0. Train Loss: 1.746658, Train Acc: 0.313979, Valid Loss: 1.441544, Valid Acc: 0.459454, Time 00:00:24\n",
"Epoch 1. Train Loss: 1.293454, Train Acc: 0.532868, Valid Loss: 1.064248, Valid Acc: 0.630934, Time 00:00:27\n",
"Epoch 2. Train Loss: 0.976450, Train Acc: 0.663643, Valid Loss: 0.865004, Valid Acc: 0.693730, Time 00:00:27\n",
"Epoch 3. Train Loss: 0.797400, Train Acc: 0.734835, Valid Loss: 0.752283, Valid Acc: 0.745847, Time 00:00:27\n",
"Epoch 4. Train Loss: 0.669900, Train Acc: 0.782389, Valid Loss: 0.654101, Valid Acc: 0.787085, Time 00:00:27\n",
"Epoch 5. Train Loss: 0.573636, Train Acc: 0.815837, Valid Loss: 0.637741, Valid Acc: 0.797765, Time 00:00:27\n",
"Epoch 6. Train Loss: 0.499975, Train Acc: 0.839174, Valid Loss: 0.571682, Valid Acc: 0.809237, Time 00:00:27\n",
"Epoch 7. Train Loss: 0.428868, Train Acc: 0.862372, Valid Loss: 0.505937, Valid Acc: 0.831982, Time 00:00:27\n",
"Epoch 8. Train Loss: 0.373129, Train Acc: 0.882133, Valid Loss: 0.559180, Valid Acc: 0.821203, Time 00:00:27\n",
"Epoch 9. Train Loss: 0.323578, Train Acc: 0.895920, Valid Loss: 0.582161, Valid Acc: 0.827927, Time 00:00:27\n",
"Epoch 10. Train Loss: 0.274420, Train Acc: 0.913163, Valid Loss: 0.623256, Valid Acc: 0.813390, Time 00:00:27\n",
"Epoch 11. Train Loss: 0.239558, Train Acc: 0.924692, Valid Loss: 0.561333, Valid Acc: 0.840882, Time 00:00:27\n",
"Epoch 12. Train Loss: 0.207196, Train Acc: 0.935042, Valid Loss: 0.537481, Valid Acc: 0.843256, Time 00:00:27\n",
"Epoch 13. Train Loss: 0.180536, Train Acc: 0.943454, Valid Loss: 0.697694, Valid Acc: 0.818631, Time 00:00:27\n",
"Epoch 14. Train Loss: 0.157166, Train Acc: 0.950607, Valid Loss: 0.542898, Valid Acc: 0.857298, Time 00:00:27\n",
"Epoch 15. Train Loss: 0.138539, Train Acc: 0.957201, Valid Loss: 0.602181, Valid Acc: 0.847112, Time 00:00:27\n",
"Epoch 16. Train Loss: 0.130947, Train Acc: 0.960418, Valid Loss: 0.607590, Valid Acc: 0.852453, Time 00:00:27\n",
"Epoch 17. Train Loss: 0.109348, Train Acc: 0.966972, Valid Loss: 0.636679, Valid Acc: 0.848497, Time 00:00:27\n",
"Epoch 18. Train Loss: 0.099000, Train Acc: 0.970009, Valid Loss: 0.640463, Valid Acc: 0.849684, Time 00:00:27\n",
"Epoch 19. Train Loss: 0.083908, Train Acc: 0.974604, Valid Loss: 0.630587, Valid Acc: 0.859771, Time 00:00:27\n"
"[ 0] Train:(L=1.689224, Acc=0.340493), Valid:(L=1.514729, Acc=0.449664), T: 00:00:27\n",
"[ 1] Train:(L=1.211734, Acc=0.572111), Valid:(L=1.043950, Acc=0.638944), T: 00:00:27\n",
"[ 2] Train:(L=0.939749, Acc=0.680647), Valid:(L=0.795742, Acc=0.731408), T: 00:00:27\n",
"[ 3] Train:(L=0.776414, Acc=0.742987), Valid:(L=0.773068, Acc=0.741792), T: 00:00:27\n",
"[ 4] Train:(L=0.655303, Acc=0.784607), Valid:(L=0.697191, Acc=0.759296), T: 00:00:27\n",
"[ 5] Train:(L=0.565006, Acc=0.816956), Valid:(L=0.634936, Acc=0.791436), T: 00:00:27\n",
"[ 6] Train:(L=0.487787, Acc=0.842172), Valid:(L=0.643098, Acc=0.788766), T: 00:00:27\n",
"[ 7] Train:(L=0.422939, Acc=0.866129), Valid:(L=0.539120, Acc=0.828323), T: 00:00:27\n",
"[ 8] Train:(L=0.365580, Acc=0.882413), Valid:(L=0.598219, Acc=0.808248), T: 00:00:27\n",
"[ 9] Train:(L=0.316299, Acc=0.899097), Valid:(L=0.601980, Acc=0.821005), T: 00:00:27\n",
"[10] Train:(L=0.271955, Acc=0.914043), Valid:(L=0.664353, Acc=0.803402), T: 00:00:27\n",
"[11] Train:(L=0.240455, Acc=0.923633), Valid:(L=0.555360, Acc=0.831191), T: 00:00:27\n",
"[12] Train:(L=0.199894, Acc=0.937100), Valid:(L=0.514384, Acc=0.857793), T: 00:00:27\n",
"[13] Train:(L=0.175543, Acc=0.944973), Valid:(L=0.641336, Acc=0.842860), T: 00:00:27\n",
"[14] Train:(L=0.149222, Acc=0.953285), Valid:(L=0.600546, Acc=0.847706), T: 00:00:27\n",
"[15] Train:(L=0.137142, Acc=0.957221), Valid:(L=0.597016, Acc=0.851958), T: 00:00:27\n",
"[16] Train:(L=0.121605, Acc=0.962456), Valid:(L=0.638970, Acc=0.850870), T: 00:00:27\n",
"[17] Train:(L=0.101306, Acc=0.970149), Valid:(L=0.602158, Acc=0.851760), T: 00:00:27\n",
"[18] Train:(L=0.103552, Acc=0.970009), Valid:(L=0.619986, Acc=0.844640), T: 00:00:27\n",
"[19] Train:(L=0.092559, Acc=0.972047), Valid:(L=0.692679, Acc=0.841772), T: 00:00:27\n"
 ]
 }
 ],


7_deep_learning/1_CNN/07-densenet.ipynb (+107, -58)

File diff suppressed because it is too large.

