# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os
import random
import sys
import time
from collections import OrderedDict
from datetime import datetime
from io import TextIOBase

import numpy as np
import torch
import torch.nn as nn

from pytorch.trainer import TorchTensorEncoder

_counter = 0


def global_mutable_counting():
    """
    Increment and return a program-level counter.

    The first call (or the first call after ``_reset_global_mutable_counting``) returns 1.
    """
    global _counter
    _counter += 1
    return _counter


def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch (plus all CUDA devices when CUDA is
    available), and switch cuDNN to deterministic mode with benchmark autotuning disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def _reset_global_mutable_counting():
    """
    Reset the global mutable counting to count from 1.
    Useful when defining multiple models with default keys.
    """
    global _counter
    _counter = 0


def to_device(obj, device):
    """
    Recursively move a tensor, or tensors nested in tuples/lists/dicts, onto ``device``.

    ``int``, ``float``, ``str`` and ``None`` leaves are returned unchanged. Dicts are rebuilt
    as plain ``dict`` objects.

    Raises
    ------
    ValueError
        If a leaf has any other type.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        return tuple(to_device(t, device) for t in obj)
    if isinstance(obj, list):
        return [to_device(t, device) for t in obj]
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if obj is None or isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
    """
    Convert a tensor or ndarray to a (nested) Python list.

    Tuples are converted to lists, lists are shallow-copied, and anything else is returned
    unchanged.
    """
    if torch.is_tensor(arr):
        return arr.cpu().numpy().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr


def count_parameters_in_MB(model):
    """
    Return the number of parameters of ``model`` divided by 1e6.

    Despite the name, this is a parameter count in millions, not a size in megabytes.
    Parameters whose name contains ``"auxiliary"`` are excluded.
    """
    # Builtin ``sum``: passing a generator to ``np.sum`` is deprecated in NumPy.
    return sum(v.numel() for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6


def str2bool(str):
    """Return True iff the given string equals ``'true'`` case-insensitively, else False."""
    # NOTE: the parameter name shadows the builtin ``str``; kept for keyword-call compatibility.
    return str.lower() == 'true'


class AverageMeterGroup:
    """
    Average meter group for multiple average meters, keyed by metric name.

    Meters are reachable both as items (``group["loss"]``) and as attributes (``group.loss``).
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Non-existent average meters will be created automatically.
        """
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. ``meters`` is read through
        # ``__dict__`` so that an instance created without ``__init__`` (``copy``/``pickle``)
        # does not recurse forever. ``AttributeError`` (not ``KeyError``) keeps ``hasattr`` and
        # ``getattr(obj, name, default)`` working.
        meters = self.__dict__.get("meters")
        if meters is None or item not in meters:
            raise AttributeError("{!r} object has no attribute {!r}".format(type(self).__name__, item))
        return meters[item]

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return " ".join(v.summary() for v in self.meters.values())

    def get_last_acc(self):
        """
        Return, as a float, the running average of the first meter in insertion order.

        Presumably this is the accuracy metric when it is the first key passed to ``update``
        (confirm with callers). The value is read directly rather than parsed back from the
        formatted summary, so it is not rounded and does not depend on the meter's name.

        Raises
        ------
        IndexError
            If the group has no meters yet.
        """
        first_meter = list(self.meters.values())[0]
        return float(first_meter.avg)


class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format spec (including the leading ``:``) used to print the values.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted in.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)


class StructuredMutableTreeNode:
    """
    A structured representation of a search space.
    A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable that current node is linked with.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """
        Add a tree node wrapping ``mutable`` to the children list of current node and return it.
        """
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        """
        Return the ``type`` of mutable content.
        """
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator that generates a list of mutables in this tree.

        Parameters
        ----------
        order : str
            pre or post. If pre, current mutable is yielded before children. Otherwise after.
        deduplicate : bool
            If true, mutables with the same key will not appear after the first appearance.
        memo : set
            An auxiliary set of keys seen before, shared across the recursion so that
            deduplication is possible.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            yield from self._visit_self(deduplicate, memo)
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from self._visit_self(deduplicate, memo)

    def _visit_self(self, deduplicate, memo):
        """Yield this node's mutable unless it is None or (when deduplicating) its key was seen."""
        if self.mutable is not None and (not deduplicate or self.mutable.key not in memo):
            memo.add(self.mutable.key)
            yield self.mutable


def dump_global_result(res_path, global_result):
    """Write ``global_result`` as indented JSON to ``res_path``, serializing tensors via ``TorchTensorEncoder``."""
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, indent=2, cls=TorchTensorEncoder)


def save_best_checkpoint(checkpoint_dir, model, optimizer, epoch):
    """
    Save model and optimizer state to ``<checkpoint_dir>/best_checkpoint_epoch{epoch}.pth``.

    The saved dict has keys ``'child_model_state_dict'``, ``'optimizer_state_dict'`` and
    ``'epoch'``. For ``DataParallel`` / ``DistributedDataParallel`` wrappers, the state dict of
    the wrapped ``module`` is saved, so the checkpoint loads into an unwrapped model.
    ``checkpoint_dir`` is created if it does not exist.
    """
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        child_model_state_dict = model.module.state_dict()
    else:
        child_model_state_dict = model.state_dict()
    save_state = {'child_model_state_dict': child_model_state_dict,
                  'optimizer_state_dict': optimizer.state_dict(),
                  'epoch': epoch}
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    dest_path = os.path.join(checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))
    torch.save(save_state, dest_path)


log_level_map = {
    'fatal': logging.FATAL,
    'error': logging.ERROR,
    'warning': logging.WARNING,
    'info': logging.INFO,
    'debug': logging.DEBUG
}

_time_format = '%m/%d/%Y, %I:%M:%S %p'


class _LoggerFileWrapper(TextIOBase):
    """
    Text stream installed as ``sys.stdout`` so that ``print`` output lands in the log file.

    Output is buffered until a newline. Each complete line is written as
    ``[<time>] PRINT <line>`` and the file is flushed immediately. Buffering means
    ``print("a", "b")`` produces one log line rather than one per written fragment.
    ``flush()`` emits any pending partial line.
    """

    def __init__(self, logger_file):
        self.file = logger_file
        self._pending = ''

    def writable(self):
        return True

    def write(self, s):
        self._pending += s
        # Everything before the last '\n' is complete; the remainder stays pending.
        *lines, self._pending = self._pending.split('\n')
        for line in lines:
            self._write_line(line)
        return len(s)

    def flush(self):
        if self._pending:
            line, self._pending = self._pending, ''
            self._write_line(line)

    def _write_line(self, line):
        """Write one timestamped PRINT line to the underlying file and flush it."""
        cur_time = datetime.now().strftime(_time_format)
        self.file.write('[{}] PRINT '.format(cur_time) + line + '\n')
        self.file.flush()


def init_logger(logger_file_path, log_level_name='info'):
    """
    Initialize the root logger and redirect ``print`` output to a log file.

    Root-logger records (including those propagated from ``logging.getLogger(...)`` loggers)
    go to stderr and to ``logger_file_path``. ``sys.stdout`` is replaced so printed lines are
    also written to that file with a timestamped ``PRINT`` prefix. The file is truncated first.

    Parameters
    ----------
    logger_file_path : path-like
        Path of the log file to (over)write.
    log_level_name : str
        One of the keys of ``log_level_map`` (matched case-insensitively).

    Raises
    ------
    ValueError
        If ``log_level_name`` is not a known level name.
    """
    try:
        log_level = log_level_map[log_level_name.lower()]
    except KeyError:
        raise ValueError("Unknown log level name {!r}; expected one of {}".format(
            log_level_name, sorted(log_level_map))) from None

    # One handle is shared by the logging handler and the stdout wrapper. Opening the path
    # twice leaves two independent file offsets, so writes through one handle overwrite what
    # the other has already written.
    logger_file = open(logger_file_path, 'w', encoding='utf-8')

    fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
    logging.Formatter.converter = time.localtime
    formatter = logging.Formatter(fmt, _time_format)

    stream_handler = logging.StreamHandler()  # defaults to sys.stderr
    stream_handler.setFormatter(formatter)
    file_handler = logging.StreamHandler(logger_file)
    file_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(stream_handler)
    root_logger.addHandler(file_handler)
    root_logger.setLevel(log_level)

    # include print function output
    sys.stdout = _LoggerFileWrapper(logger_file)


def mkdirs(*args):
    """
    Ensure the parent directory of every given file path exists.

    A message is printed for each directory that had to be created. Paths without a
    directory component are ignored.
    """
    for path in args:
        dirname = os.path.dirname(path)
        if dirname and not os.path.exists(dirname):
            print("make {} in dir: {}".format(path, dirname))
            # exist_ok: another process may create it between the check and this call.
            os.makedirs(dirname, exist_ok=True)


def list_str2int(ls):
    """Convert every element of ``ls`` with ``int`` and return the results as a list."""
    return [int(x) for x in ls]