|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- from datetime import datetime
-
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.autograd import Variable
-
-
def get_acc(output, label):
    """Return the fraction of rows whose highest-scoring index equals the label.

    Args:
        output: Score tensor; the prediction for each row is the index of the
            largest value along dimension 1.
        label: Tensor of target indices, compared element-wise with the
            predictions.

    Returns:
        The number of matching predictions divided by ``output.shape[0]``.
    """
    # Tensor.max(dim) yields (values, indices); only the indices are needed.
    predictions = output.max(1)[1]
    hits = (predictions == label).sum().item()
    batch_size = output.shape[0]
    return hits / batch_size
-
-
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train ``net`` for ``num_epochs`` epochs and print per-epoch metrics.

    The model and every batch are moved to the GPU when CUDA is available.
    After each epoch one summary line is printed containing the training loss
    and accuracy averaged over batches, the same two metrics for
    ``valid_data`` when it is given, and the elapsed wall-clock time.

    Args:
        net: Model to optimise; invoked as ``net(im)``.
        train_data: Sized iterable yielding ``(im, label)`` batches.
        valid_data: Sized iterable yielding ``(im, label)`` batches, or
            ``None`` to skip validation.
        num_epochs: Number of passes over ``train_data``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        criterion: Callable ``criterion(output, label)`` returning a loss
            tensor that supports ``backward()`` and ``item()``.

    Returns:
        None. Results are reported through ``print``.
    """
    # Query CUDA once instead of once per batch.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            # Plain tensors are used directly: the Variable wrapper is a no-op
            # on the PyTorch versions that provide Tensor.item().
            if use_cuda:
                im = im.cuda()
                label = label.cuda()
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        # The timestamp is sampled before validation, so the time printed for
        # an epoch covers its training pass plus the previous epoch's
        # validation pass.
        cur_time = datetime.now()
        # total_seconds() keeps whole days; timedelta.seconds would wrap
        # around for an epoch lasting 24 hours or more.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # no_grad() replaces the removed ``volatile=True`` flag so that no
            # autograd graph is recorded during evaluation.
            with torch.no_grad():
                for im, label in valid_data:
                    if use_cuda:
                        im = im.cuda()
                        label = label.cuda()
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)
-
-
def conv3x3(in_channel, out_channel, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
-
-
class residual_block(nn.Module):
    """Two 3x3 convolution + batch-norm stages with a skip connection.

    When ``same_shape`` is false the first convolution uses stride 2 and the
    input is routed through a stride-2 1x1 convolution before the addition,
    so the skip path matches the main path in channels and spatial size.
    When ``same_shape`` is true the input is added unchanged, which requires
    ``in_channel == out_channel``.
    """

    def __init__(self, in_channel, out_channel, same_shape=True):
        """Create the block's layers.

        Args:
            in_channel: Number of channels of the block input.
            out_channel: Number of channels produced by the block.
            same_shape: If true, keep stride 1 and add the input as-is;
                otherwise downsample with stride 2 and project the input
                with ``conv3``.
        """
        super(residual_block, self).__init__()
        self.same_shape = same_shape
        first_stride = 1 if same_shape else 2

        self.conv1 = conv3x3(in_channel, out_channel, stride=first_stride)
        self.bn1 = nn.BatchNorm2d(out_channel)

        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)

        # The 1x1 projection only exists on the downsampling variant.
        if not same_shape:
            self.conv3 = nn.Conv2d(
                in_channel, out_channel, 1, stride=first_stride)

    def forward(self, x):
        """Return ``relu(skip(x) + main(x))`` for input ``x``."""
        # Main path: conv -> batch norm -> in-place ReLU, applied twice.
        residual = F.relu(self.bn1(self.conv1(x)), True)
        residual = F.relu(self.bn2(self.conv2(residual)), True)

        # Skip path: identity, or the 1x1 projection when shapes differ.
        shortcut = x if self.same_shape else self.conv3(x)
        return F.relu(shortcut + residual, True)
-
-
class resnet(nn.Module):
    """Five-stage residual network followed by a linear classifier.

    Stage 1 is a 7x7 stride-2 convolution. Stages 2-5 are pairs of
    ``residual_block`` modules producing 64, 128, 256 and 512 channels, with
    a max-pool at the start of stage 2 and a 3x3 average-pool at the end of
    stage 5. The pooled output is flattened and fed to ``nn.Linear(512, ...)``,
    so the input size must leave exactly 512 features per sample after
    pooling.
    """

    def __init__(self, in_channel, num_classes, verbose=False):
        """Create the network's layers.

        Args:
            in_channel: Number of channels of the input images.
            num_classes: Size of the classifier output.
            verbose: If true, ``forward`` prints the output shape of each of
                the five stages.
        """
        super(resnet, self).__init__()
        self.verbose = verbose

        self.block1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2)

        self.block2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
            residual_block(64, 64),
            residual_block(64, 64),
        )

        self.block3 = nn.Sequential(
            residual_block(64, 128, same_shape=False),
            residual_block(128, 128),
        )

        self.block4 = nn.Sequential(
            residual_block(128, 256, same_shape=False),
            residual_block(256, 256),
        )

        self.block5 = nn.Sequential(
            residual_block(256, 512, same_shape=False),
            residual_block(512, 512),
            nn.AvgPool2d(kernel_size=3),
        )

        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run ``x`` through the five stages and the classifier.

        Returns:
            The output of ``self.classifier`` applied to the flattened
            stage-5 features.
        """
        # Each stage is paired with the message printed in verbose mode.
        stages = (
            (self.block1, 'block 1 output: {}'),
            (self.block2, 'block 2 output: {}'),
            (self.block3, 'block 3 output: {}'),
            (self.block4, 'block 4 output: {}'),
            (self.block5, 'block 5 output: {}'),
        )
        for stage, message in stages:
            x = stage(x)
            if self.verbose:
                print(message.format(x.shape))

        # Collapse everything but the batch dimension before the linear layer.
        flat = x.view(x.shape[0], -1)
        return self.classifier(flat)
|