You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 9.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. from collections import OrderedDict
  4. import json
  5. import random
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. import os
  10. from datetime import datetime
  11. from io import TextIOBase
  12. import logging
  13. import sys
  14. import time
  15. from pytorch.trainer import TorchTensorEncoder
  16. _counter = 0
  17. def global_mutable_counting():
  18. """
  19. A program level counter starting from 1.
  20. """
  21. global _counter
  22. _counter += 1
  23. return _counter
  24. def set_seed(seed):
  25. random.seed(seed)
  26. np.random.seed(seed)
  27. torch.manual_seed(seed)
  28. if torch.cuda.is_available():
  29. torch.cuda.manual_seed_all(seed)
  30. torch.backends.cudnn.benchmark = False
  31. torch.backends.cudnn.deterministic = True
  32. def _reset_global_mutable_counting():
  33. """
  34. Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
  35. """
  36. global _counter
  37. _counter = 0
  38. def to_device(obj, device):
  39. """
  40. Move a tensor, tuple, list, or dict onto device.
  41. """
  42. if torch.is_tensor(obj):
  43. return obj.to(device)
  44. if isinstance(obj, tuple):
  45. return tuple(to_device(t, device) for t in obj)
  46. if isinstance(obj, list):
  47. return [to_device(t, device) for t in obj]
  48. if isinstance(obj, dict):
  49. return {k: to_device(v, device) for k, v in obj.items()}
  50. if isinstance(obj, (int, float, str)):
  51. return obj
  52. raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
  53. def to_list(arr):
  54. if torch.is_tensor(arr):
  55. return arr.cpu().numpy().tolist()
  56. if isinstance(arr, np.ndarray):
  57. return arr.tolist()
  58. if isinstance(arr, (list, tuple)):
  59. return list(arr)
  60. return arr
  61. def count_parameters_in_MB(model):
  62. return np.sum(
  63. np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
  64. def str2bool(str):
  65. return True if str.lower() == 'true' else False
  66. class AverageMeterGroup:
  67. """
  68. Average meter group for multiple average meters.
  69. """
  70. def __init__(self):
  71. self.meters = OrderedDict()
  72. def update(self, data):
  73. """
  74. Update the meter group with a dict of metrics.
  75. Non-exist average meters will be automatically created.
  76. """
  77. for k, v in data.items():
  78. if k not in self.meters:
  79. self.meters[k] = AverageMeter(k, ":4f")
  80. self.meters[k].update(v)
  81. def __getattr__(self, item):
  82. return self.meters[item]
  83. def __getitem__(self, item):
  84. return self.meters[item]
  85. def __str__(self):
  86. return " ".join(str(v) for v in self.meters.values())
  87. def summary(self):
  88. """
  89. Return a summary string of group data.
  90. """
  91. return " ".join(v.summary() for v in self.meters.values())
  92. def get_last_acc(self):
  93. return float([v.summary() for v in self.meters.values()][0].split(': ')[1])
  94. class AverageMeter:
  95. """
  96. Computes and stores the average and current value.
  97. Parameters
  98. ----------
  99. name : str
  100. Name to display.
  101. fmt : str
  102. Format string to print the values.
  103. """
  104. def __init__(self, name, fmt=':f'):
  105. self.name = name
  106. self.fmt = fmt
  107. self.reset()
  108. def reset(self):
  109. """
  110. Reset the meter.
  111. """
  112. self.val = 0
  113. self.avg = 0
  114. self.sum = 0
  115. self.count = 0
  116. def update(self, val, n=1):
  117. """
  118. Update with value and weight.
  119. Parameters
  120. ----------
  121. val : float or int
  122. The new value to be accounted in.
  123. n : int
  124. The weight of the new value.
  125. """
  126. self.val = val
  127. self.sum += val * n
  128. self.count += n
  129. self.avg = self.sum / self.count
  130. def __str__(self):
  131. fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
  132. return fmtstr.format(**self.__dict__)
  133. def summary(self):
  134. fmtstr = '{name}: {avg' + self.fmt + '}'
  135. return fmtstr.format(**self.__dict__)
  136. class StructuredMutableTreeNode:
  137. """
  138. A structured representation of a search space.
  139. A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
  140. This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
  141. the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
  142. ``Mutable`` (other than ``MutableScope``).
  143. Parameters
  144. ----------
  145. mutable : nni.nas.pytorch.mutables.Mutable
  146. The mutable that current node is linked with.
  147. """
  148. def __init__(self, mutable):
  149. self.mutable = mutable
  150. self.children = []
  151. def add_child(self, mutable):
  152. """
  153. Add a tree node to the children list of current node.
  154. """
  155. self.children.append(StructuredMutableTreeNode(mutable))
  156. return self.children[-1]
  157. def type(self):
  158. """
  159. Return the ``type`` of mutable content.
  160. """
  161. return type(self.mutable)
  162. def __iter__(self):
  163. return self.traverse()
  164. def traverse(self, order="pre", deduplicate=True, memo=None):
  165. """
  166. Return a generator that generates a list of mutables in this tree.
  167. Parameters
  168. ----------
  169. order : str
  170. pre or post. If pre, current mutable is yield before children. Otherwise after.
  171. deduplicate : bool
  172. If true, mutables with the same key will not appear after the first appearance.
  173. memo : dict
  174. An auxiliary dict that memorize keys seen before, so that deduplication is possible.
  175. Returns
  176. -------
  177. generator of Mutable
  178. """
  179. if memo is None:
  180. memo = set()
  181. assert order in ["pre", "post"]
  182. if order == "pre":
  183. if self.mutable is not None:
  184. if not deduplicate or self.mutable.key not in memo:
  185. memo.add(self.mutable.key)
  186. yield self.mutable
  187. for child in self.children:
  188. for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
  189. yield m
  190. if order == "post":
  191. if self.mutable is not None:
  192. if not deduplicate or self.mutable.key not in memo:
  193. memo.add(self.mutable.key)
  194. yield self.mutable
  195. def dump_global_result(res_path, global_result):
  196. with open(res_path, "w") as ss_file:
  197. json.dump(global_result, ss_file, indent=2, cls=TorchTensorEncoder)
  198. def save_best_checkpoint(checkpoint_dir, model, optimizer, epoch):
  199. """
  200. Dump to 'best_checkpoint_epoch{}.pth.tar'.format(epoch)' on last epoch end.
  201. ``DataParallel`` object will have their inside modules exported.
  202. """
  203. if isinstance(model, nn.DataParallel):
  204. child_model_state_dict = model.module.state_dict()
  205. else:
  206. child_model_state_dict = model.state_dict()
  207. save_state = {'child_model_state_dict': child_model_state_dict,
  208. 'optimizer_state_dict': optimizer.state_dict(),
  209. 'epoch': epoch}
  210. dest_path = os.path.join(checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))
  211. torch.save(save_state, dest_path)
  212. log_level_map = {
  213. 'fatal': logging.FATAL,
  214. 'error': logging.ERROR,
  215. 'warning': logging.WARNING,
  216. 'info': logging.INFO,
  217. 'debug': logging.DEBUG
  218. }
  219. _time_format = '%m/%d/%Y, %I:%M:%S %p'
  220. class _LoggerFileWrapper(TextIOBase):
  221. def __init__(self, logger_file):
  222. self.file = logger_file
  223. def write(self, s):
  224. if s != '\n':
  225. cur_time = datetime.now().strftime(_time_format)
  226. self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
  227. self.file.flush()
  228. return len(s)
  229. def init_logger(logger_file_path, log_level_name='info'):
  230. """Initialize root logger.
  231. This will redirect anything from logging.getLogger() as well as stdout to specified file.
  232. logger_file_path: path of logger file (path-like object).
  233. """
  234. log_level = log_level_map.get(log_level_name)
  235. logger_file = open(logger_file_path, 'w')
  236. fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
  237. logging.Formatter.converter = time.localtime
  238. formatter = logging.Formatter(fmt, _time_format)
  239. stream_handler = logging.StreamHandler()
  240. stream_handler.setFormatter(formatter)
  241. file_handler = logging.FileHandler(logger_file_path)
  242. file_handler.setFormatter(formatter)
  243. root_logger = logging.getLogger()
  244. root_logger.addHandler(stream_handler)
  245. root_logger.addHandler(file_handler)
  246. root_logger.setLevel(log_level)
  247. # include print function output
  248. sys.stdout = _LoggerFileWrapper(logger_file)
  249. def mkdirs(*args):
  250. for path in args:
  251. dirname = os.path.dirname(path)
  252. if dirname and not os.path.exists(dirname):
  253. print("make {} in dir: {}".format(path, dirname))
  254. os.makedirs(dirname)
  255. def list_str2int(ls):
  256. return list(map(lambda x: int(x), ls))

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能