
utils.py 5.4 kB

from datetime import datetime

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable  # legacy wrapper; a no-op on tensors in PyTorch >= 0.4


def get_acc(output, label):
    """Fraction of predictions (argmax over the class dimension of `output`) that match `label`."""
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion, use_cuda=True):
    """Train `net` on `train_data`, optionally evaluating on `valid_data` each epoch.

    Returns per-epoch lists: (train_loss, train_acc, valid_loss, valid_acc).
    """
    if use_cuda and torch.cuda.is_available():
        net = net.cuda()
    l_train_loss = []
    l_train_acc = []
    l_valid_loss = []
    l_valid_acc = []
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            if use_cuda and torch.cuda.is_available():
                im = Variable(im.cuda())        # (bs, 3, h, w)
                label = Variable(label.cuda())  # (bs,)
            else:
                im = Variable(im)
                label = Variable(label)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += get_acc(output, label)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                if use_cuda and torch.cuda.is_available():
                    im = Variable(im.cuda())
                    label = Variable(label.cuda())
                else:
                    im = Variable(im)
                    label = Variable(label)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                "[%2d] Train:(L=%f, Acc=%f), Valid:(L=%f, Acc=%f), "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
            l_valid_acc.append(valid_acc / len(valid_data))
            l_valid_loss.append(valid_loss / len(valid_data))
        else:
            epoch_str = ("[%2d] Train:(L=%f, Acc=%f), " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        l_train_acc.append(train_acc / len(train_data))
        l_train_loss.append(train_loss / len(train_data))
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "T: %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time
        print(epoch_str + time_str)
    return (l_train_loss, l_train_acc, l_valid_loss, l_valid_acc)


def conv3x3(in_channel, out_channel, stride=1):
    """3x3 convolution with padding 1 and no bias (the following BatchNorm supplies the bias)."""
    return nn.Conv2d(
        in_channel, out_channel, 3, stride=stride, padding=1, bias=False)


class residual_block(nn.Module):
    def __init__(self, in_channel, out_channel, same_shape=True):
        super(residual_block, self).__init__()
        self.same_shape = same_shape
        stride = 1 if self.same_shape else 2
        self.conv1 = conv3x3(in_channel, out_channel, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        if not self.same_shape:
            # 1x1 convolution so the shortcut matches the downsampled shape
            self.conv3 = nn.Conv2d(in_channel, out_channel, 1, stride=stride)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(out), True)
        out = self.conv2(out)
        out = F.relu(self.bn2(out), True)
        if not self.same_shape:
            x = self.conv3(x)
        return F.relu(x + out, True)


class resnet(nn.Module):
    def __init__(self, in_channel, num_classes, verbose=False):
        super(resnet, self).__init__()
        self.verbose = verbose
        self.block1 = nn.Conv2d(in_channel, 64, 7, 2)
        self.block2 = nn.Sequential(
            nn.MaxPool2d(3, 2), residual_block(64, 64), residual_block(64, 64))
        self.block3 = nn.Sequential(
            residual_block(64, 128, False), residual_block(128, 128))
        self.block4 = nn.Sequential(
            residual_block(128, 256, False), residual_block(256, 256))
        self.block5 = nn.Sequential(
            residual_block(256, 512, False),
            residual_block(512, 512), nn.AvgPool2d(3))
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.block1(x)
        if self.verbose:
            print('block 1 output: {}'.format(x.shape))
        x = self.block2(x)
        if self.verbose:
            print('block 2 output: {}'.format(x.shape))
        x = self.block3(x)
        if self.verbose:
            print('block 3 output: {}'.format(x.shape))
        x = self.block4(x)
        if self.verbose:
            print('block 4 output: {}'.format(x.shape))
        x = self.block5(x)
        if self.verbose:
            print('block 5 output: {}'.format(x.shape))
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x
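
A minimal usage sketch (not part of utils.py): it builds the resnet for 10 classes and drives train() with random 96x96 RGB tensors wrapped in a DataLoader, passing None for valid_data. The image size, class count, batch size, and optimizer settings are illustrative assumptions; 96x96 is chosen so the fixed AvgPool2d(3) in block5 receives a 3x3 feature map.

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    from utils import resnet, train  # this file, saved as utils.py

    # Fake data standing in for a real dataset (assumed sizes, for illustration only).
    images = torch.randn(64, 3, 96, 96)       # 64 random 96x96 RGB images
    labels = torch.randint(0, 10, (64,))      # random labels for 10 classes
    train_loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

    net = resnet(in_channel=3, num_classes=10, verbose=False)
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    # valid_data=None skips the validation pass; use_cuda=False keeps this runnable on CPU.
    history = train(net, train_loader, None, num_epochs=2,
                    optimizer=optimizer, criterion=criterion, use_cuda=False)
    train_losses, train_accs, valid_losses, valid_accs = history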

Machine learning is increasingly applied in fields such as aerial vehicles and robotics, with the goal of using computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to guide students in mastering the basic knowledge, typical methods, and techniques of machine learning, to spark interest in the discipline through concrete application cases, and to encourage students to analyze and solve the problems and challenges facing aerial vehicles and robots from the perspective of artificial intelligence. The main content includes Python programming fundamentals, machine learning models, the basics and implementation of unsupervised learning, supervised learning, and deep learning, and how to apply machine learning to real-world problems, thereby comprehensively improving students' overall abilities.