You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 630 B

2 years ago
123456789101112131415161718192021
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. def accuracy(output, target, topk=(1,)):
  4. """ Computes the precision@k for the specified values of k """
  5. maxk = max(topk)
  6. batch_size = target.size(0)
  7. _, pred = output.topk(maxk, 1, True, True)
  8. pred = pred.t()
  9. # one-hot case
  10. if target.ndimension() > 1:
  11. target = target.max(1)[1]
  12. correct = pred.eq(target.view(1, -1).expand_as(pred))
  13. res = dict()
  14. for k in topk:
  15. correct_k = correct[:k].reshape(-1).float().sum(0)
  16. res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
  17. return res

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能