You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 1.4 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import torch
  4. import torch.nn as nn
  5. class CrossEntropyLabelSmooth(nn.Module):
  6. def __init__(self, num_classes, epsilon):
  7. super(CrossEntropyLabelSmooth, self).__init__()
  8. self.num_classes = num_classes
  9. self.epsilon = epsilon
  10. self.logsoftmax = nn.LogSoftmax(dim=1)
  11. def forward(self, inputs, targets):
  12. log_probs = self.logsoftmax(inputs)
  13. # todo , device="cuda:6"
  14. targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
  15. targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
  16. loss = (-targets * log_probs).mean(0).sum()
  17. return loss
  18. def accuracy(output, target, topk=(1, 5)):
  19. """ Computes the precision@k for the specified values of k """
  20. maxk = max(topk)
  21. batch_size = target.size(0)
  22. _, pred = output.topk(maxk, 1, True, True)
  23. pred = pred.t()
  24. # one-hot case
  25. if target.ndimension() > 1:
  26. target = target.max(1)[1]
  27. correct = pred.eq(target.view(1, -1).expand_as(pred))
  28. res = dict()
  29. for k in topk:
  30. # correct_k = correct[:k].view(-1).float().sum(0) # 原始结果
  31. correct_k = correct[:k].reshape(-1).float().sum(0) # .view(-1)不支持
  32. res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
  33. return res

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能