You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 829 B

2 years ago
123456789101112131415161718192021222324252627282930
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import torch
  4. def accuracy(output, target, topk=(1,)):
  5. """ Computes the precision@k for the specified values of k """
  6. maxk = max(topk)
  7. batch_size = target.size(0)
  8. _, pred = output.topk(maxk, 1, True, True)
  9. pred = pred.t()
  10. # one-hot case
  11. if target.ndimension() > 1:
  12. target = target.max(1)[1]
  13. correct = pred.eq(target.view(1, -1).expand_as(pred))
  14. res = dict()
  15. for k in topk:
  16. correct_k = correct[:k].view(-1).float().sum(0)
  17. res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
  18. return res
  19. def reward_accuracy(output, target, topk=(1,)):
  20. batch_size = target.size(0)
  21. _, predicted = torch.max(output.data, 1)
  22. return (predicted == target).sum().item() / batch_size

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能