You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

misc.py 2.1 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.07.31
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: This script is for logging
  9. """
  10. import os
  11. import torch
  12. import time
  13. def make_folder(folder_name):
  14. if not os.path.exists(folder_name):
  15. os.makedirs(folder_name)
  16. # or os.makedirs(folder_name, exist_ok=True)
  17. class AverageMeter(object):
  18. """Computes and stores the average and current value"""
  19. def __init__(self, name, fmt=':f'):
  20. self.name = name
  21. self.fmt = fmt
  22. self.reset()
  23. def reset(self):
  24. self.val = 0
  25. self.avg = 0
  26. self.sum = 0
  27. self.count = 0
  28. def update(self, val, n=1):
  29. self.val = val
  30. self.sum += val * n
  31. self.count += n
  32. self.avg = self.sum / self.count
  33. def __str__(self):
  34. fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
  35. return fmtstr.format(**self.__dict__)
  36. class ProgressMeter(object):
  37. def __init__(self, num_batches, meters, prefix=""):
  38. self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
  39. self.meters = meters
  40. self.prefix = prefix
  41. def display(self, batch):
  42. entries = [self.prefix + self.batch_fmtstr.format(batch)]
  43. entries += [str(meter) for meter in self.meters]
  44. print('\t'.join(entries))
  45. def get_message(self, batch):
  46. entries = [self.prefix + self.batch_fmtstr.format(batch)]
  47. entries += [str(meter) for meter in self.meters]
  48. return '\t'.join(entries)
  49. def _get_batch_fmtstr(self, num_batches):
  50. num_digits = len(str(num_batches // 1))
  51. fmt = '{:' + str(num_digits) + 'd}'
  52. return '[' + fmt + '/' + fmt.format(num_batches) + ']'
  53. def time_synchronized():
  54. torch.cuda.synchronize() if torch.cuda.is_available() else None
  55. return time.time()

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能