import os
import argparse
import logging
import random

import numpy as np
import torch
import torch.nn as nn
from dataloader import get_imagenet_iter_dali
from pytorch.fixed import apply_fixed_architecture
from pytorch.utils import AverageMeterGroup
from torch.utils.tensorboard import SummaryWriter
# import torch.distributed as dist
from network import ShuffleNetV2OneShot
from utils import CrossEntropyLabelSmooth, accuracy

# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

logger = logging.getLogger("nni.spos.scratch")


def train(epoch, model, criterion, optimizer, loader, writer, args):
    model.train()
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]
    for step, (x, y) in enumerate(loader):
        cur_step = len(loader) * epoch + step
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        metrics = accuracy(logits, y)
        metrics["loss"] = loss.item()
        meters.update(metrics)
        writer.add_scalar("lr", cur_lr, global_step=cur_step)
        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
        writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step)
        writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step)

        if step % args.log_frequency == 0 or step + 1 == len(loader):
            logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, args.epochs,
                        step + 1, len(loader), meters)
        if step > len(loader):
            # Guard for DALI iterators that can yield more batches than len(loader).
            break
    logger.info("Epoch %d training summary: %s", epoch + 1, meters)


def validate(epoch, model, criterion, loader, writer, args):
    model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            logits = model(x)
            loss = criterion(logits, y)
            metrics = accuracy(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if step % args.log_frequency == 0 or step + 1 == len(loader):
                logger.info("Epoch [%d/%d] Validation Step [%d/%d] %s", epoch + 1,
                            args.epochs, step + 1, len(loader), meters)
            if step > len(loader):
                # Same guard as in train() for DALI iterators.
                break
    writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
    writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
    writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)
    logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1,
                meters.acc1.avg, meters.acc5.avg)


def dump_checkpoint(model, epoch, checkpoint_dir):
    # Unwrap DataParallel so checkpoint keys carry no "module." prefix.
    if isinstance(model, nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    dest_path = os.path.join(checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
    logger.info("Saving model to %s", dest_path)
    torch.save(state_dict, dest_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("SPOS Training From Scratch")
    parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet")  # ./data/imagenet
    parser.add_argument("--tb-dir", type=str, default="runs")
    parser.add_argument("--architecture", type=str, default="./checkpoints/037_034.json")  # "architecture_final.json"
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--batch-size", type=int, default=1024)
    parser.add_argument("--epochs", type=int, default=240)
    parser.add_argument("--learning-rate", type=float, default=0.5)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=4E-5)
    parser.add_argument("--label-smoothing", type=float, default=0.1)
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--lr-decay", type=str, default="linear")
parser.add_argument("--seed", type=int, default=42) parser.add_argument("--spos-preprocessing", default=False, action="store_true") parser.add_argument("--label-smoothing", type=float, default=0.1) parser.add_argument("--local_rank", default=[0,1,2,3]) args = parser.parse_args() torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) np.random.seed(args.seed) random.seed(args.seed) torch.backends.cudnn.deterministic = True model = ShuffleNetV2OneShot(affine=True) model.cuda("cuda:0") apply_fixed_architecture(model, args.architecture) # state_dict是否发生变化 state_dict = model.state_dict() # todo DDP并行的一些设置 # dist.init_process_group(backend = "nccl") if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1))) # todo # device_ids=list(range(0, torch.cuda.device_count() - 1)) # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.local_rank) criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing) optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) if args.lr_decay == "linear": scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1.0 - step / args.epochs) if step <= args.epochs else 0, last_epoch=-1) elif args.lr_decay == "cosine": scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 1E-3) else: raise ValueError("'%s' not supported." % args.lr_decay) writer = SummaryWriter(log_dir=args.tb_dir) train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers, spos_preprocessing=args.spos_preprocessing) val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers, spos_preprocessing=args.spos_preprocessing) for epoch in range(args.epochs): train(epoch, model, criterion, optimizer, train_loader, writer, args) validate(epoch, model, criterion, val_loader, writer, args) scheduler.step() dump_checkpoint(model, epoch, "scratch_checkpoints") writer.close()