
utils.py 4.9 kB

from datetime import datetime

import torch
import torch.nn.functional as F
from torch import nn


def get_acc(output, label):
    """Batch classification accuracy: fraction of samples whose argmax matches the label."""
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train `net` on `train_data`, evaluating on `valid_data` after each epoch when given."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)        # (bs, 3, h, w)
            label = label.to(device)  # (bs, h, w)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)

        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # no gradients are needed for validation (replaces the removed `volatile=True` flag)
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)
                    label = label.to(device)
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


def conv3x3(in_channel, out_channel, stride=1):
    """3x3 convolution with padding 1, so spatial size is preserved when stride=1."""
    return nn.Conv2d(
        in_channel, out_channel, 3, stride=stride, padding=1, bias=False)


class residual_block(nn.Module):
    """Basic residual block: two conv-BN-ReLU layers plus a shortcut connection.

    When `same_shape` is False the block downsamples by 2 and a 1x1 convolution
    projects the input so it can be added to the block output.
    """

    def __init__(self, in_channel, out_channel, same_shape=True):
        super(residual_block, self).__init__()
        self.same_shape = same_shape
        stride = 1 if self.same_shape else 2

        self.conv1 = conv3x3(in_channel, out_channel, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channel)

        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        if not self.same_shape:
            self.conv3 = nn.Conv2d(in_channel, out_channel, 1, stride=stride)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(out), True)
        out = self.conv2(out)
        out = F.relu(self.bn2(out), True)

        if not self.same_shape:
            x = self.conv3(x)
        return F.relu(x + out, True)


class resnet(nn.Module):
    """Small ResNet: a stem convolution, four residual stages, and a linear classifier."""

    def __init__(self, in_channel, num_classes, verbose=False):
        super(resnet, self).__init__()
        self.verbose = verbose

        self.block1 = nn.Conv2d(in_channel, 64, 7, 2)

        self.block2 = nn.Sequential(
            nn.MaxPool2d(3, 2), residual_block(64, 64), residual_block(64, 64))

        self.block3 = nn.Sequential(
            residual_block(64, 128, False), residual_block(128, 128))

        self.block4 = nn.Sequential(
            residual_block(128, 256, False), residual_block(256, 256))

        self.block5 = nn.Sequential(
            residual_block(256, 512, False),
            residual_block(512, 512), nn.AvgPool2d(3))

        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.block1(x)
        if self.verbose:
            print('block 1 output: {}'.format(x.shape))
        x = self.block2(x)
        if self.verbose:
            print('block 2 output: {}'.format(x.shape))
        x = self.block3(x)
        if self.verbose:
            print('block 3 output: {}'.format(x.shape))
        x = self.block4(x)
        if self.verbose:
            print('block 4 output: {}'.format(x.shape))
        x = self.block5(x)
        if self.verbose:
            print('block 5 output: {}'.format(x.shape))
        x = x.view(x.shape[0], -1)  # flatten to (bs, 512) before the classifier
        x = self.classifier(x)
        return x
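
Below is a minimal usage sketch (not part of utils.py) showing how the resnet model and the train loop fit together. The CIFAR-10 dataset, the 96x96 resize, the batch sizes, and the SGD hyperparameters are illustrative assumptions rather than settings taken from the repository; the 96x96 input size is chosen so that block5 produces a 512x1x1 feature map, which is what nn.Linear(512, num_classes) expects.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from utils import resnet, train

# Shape sanity check: a 3x96x96 input reaches the classifier as a 512-dim vector.
net = resnet(in_channel=3, num_classes=10)
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 96, 96))
print(out.shape)  # torch.Size([1, 10])

# Illustrative training run on CIFAR-10 resized to 96x96.
tf = transforms.Compose([transforms.Resize(96), transforms.ToTensor()])
train_set = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
valid_set = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
valid_data = DataLoader(valid_set, batch_size=128, shuffle=False)

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
train(net, train_data, valid_data, num_epochs=20, optimizer=optimizer, criterion=criterion)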

Machine learning is increasingly applied to fields such as aircraft and robotics; its goal is to use computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to guide students in mastering the basic knowledge, typical methods, and techniques of machine learning, to spark their interest in the discipline through concrete application cases, and to encourage them to analyze and solve the problems and challenges faced by aircraft and robots from the perspective of artificial intelligence. The main content covers Python programming fundamentals, machine learning models, the basics and implementation of unsupervised learning, supervised learning, and deep learning, and how to apply machine learning to practical problems, thereby improving students' overall abilities.