You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 956 B

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import os
  4. import torch
  5. import torch.distributed as dist
  6. def accuracy(output, target, topk=(1,)):
  7. """ Computes the precision@k for the specified values of k """
  8. maxk = max(topk)
  9. batch_size = target.size(0)
  10. _, pred = output.topk(maxk, 1, True, True)
  11. pred = pred.t()
  12. # one-hot case
  13. if target.ndimension() > 1:
  14. target = target.max(1)[1]
  15. correct = pred.eq(target.reshape(1, -1).expand_as(pred))
  16. res = []
  17. for k in topk:
  18. correct_k = correct[:k].reshape(-1).float().sum(0)
  19. res.append(correct_k.mul_(1.0 / batch_size))
  20. return res
  21. def reduce_metrics(metrics):
  22. return {k: reduce_tensor(v).item() for k, v in metrics.items()}
  23. def reduce_tensor(tensor):
  24. rt = torch.sum(tensor)
  25. # rt = tensor.clone()
  26. # dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  27. # rt /= float(os.environ["WORLD_SIZE"])
  28. return rt

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能