You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

torch_utils.py 1.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.09
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: some utilities of torch (conversion)
  9. -----------------------------------------------------------------------------------
  10. """
  11. import torch
  12. import torch.distributed as dist
  13. __all__ = ['convert2cpu', 'convert2cpu_long', 'to_cpu', 'reduce_tensor', 'to_python_float', '_sigmoid']
  14. def convert2cpu(gpu_matrix):
  15. return torch.FloatTensor(gpu_matrix.size()).copy_(gpu_matrix)
  16. def convert2cpu_long(gpu_matrix):
  17. return torch.LongTensor(gpu_matrix.size()).copy_(gpu_matrix)
  18. def to_cpu(tensor):
  19. return tensor.detach().cpu()
  20. def reduce_tensor(tensor, world_size):
  21. rt = tensor.clone()
  22. dist.all_reduce(rt, op=dist.reduce_op.SUM)
  23. rt /= world_size
  24. return rt
  25. def to_python_float(t):
  26. if hasattr(t, 'item'):
  27. return t.item()
  28. else:
  29. return t[0]
  30. def _sigmoid(x):
  31. return torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能