retrain.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
import logging
import pickle
import shutil
import random
import math
import time
import datetime
import argparse
import distutils.util

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as Func

from model import Model
from pytorch.fixed import apply_fixed_architecture
from dataloader import read_data_sst

logger = logging.getLogger("nni.textnas")

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--reset_output_dir",
        type=distutils.util.strtobool,
        default=True,
        help="Whether to clean the output directory if it exists. (default: %(default)s)")
    parser.add_argument(
        "--child_fixed_arc",
        type=str,
        required=True,
        help="Architecture JSON file. (default: %(default)s)")
    parser.add_argument(
        "--data_path",
        type=str,
        default="data",
        help="Directory containing the dataset and embedding file. (default: %(default)s)")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_decay_scheme",
        type=str,
        default="cosine",
        help="Learning rate annealing strategy; only 'cosine' is supported. (default: %(default)s)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for training. (default: %(default)s)")
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for evaluation. (default: %(default)s)")
    parser.add_argument(
        "--class_num",
        type=int,
        default=5,
        help="The number of categories. (default: %(default)s)")
    parser.add_argument(
        "--global_seed",
        type=int,
        default=1234,
        help="Seed for reproducibility. (default: %(default)s)")
    parser.add_argument(
        "--max_input_length",
        type=int,
        default=64,
        help="The maximum length of the sentence. (default: %(default)s)")
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=10,
        help="The number of training epochs. (default: %(default)s)")
    parser.add_argument(
        "--child_num_layers",
        type=int,
        default=24,
        help="The number of layers in the architecture. (default: %(default)s)")
    parser.add_argument(
        "--child_out_filters",
        type=int,
        default=256,
        help="The dimension of hidden states. (default: %(default)s)")
    parser.add_argument(
        "--child_out_filters_scale",
        type=int,
        default=1,
        help="The scale factor for the hidden state dimension. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_T_0",
        type=int,
        default=10,
        help="The length of the first cosine cycle, in epochs. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_T_mul",
        type=int,
        default=2,
        help="The multiplication factor per cycle. (default: %(default)s)")
    parser.add_argument(
        "--min_count",
        type=int,
        default=1,
        help="The threshold for cutting off low-frequency words. (default: %(default)s)")
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=1.0,
        help="The sample ratio for the training set. (default: %(default)s)")
    parser.add_argument(
        "--valid_ratio",
        type=float,
        default=1.0,
        help="The sample ratio for the dev set. (default: %(default)s)")
    parser.add_argument(
        "--child_grad_bound",
        type=float,
        default=5.0,
        help="The threshold for gradient clipping. (default: %(default)s)")
    parser.add_argument(
        "--child_lr",
        type=float,
        default=0.02,
        help="The initial learning rate. (default: %(default)s)")
    parser.add_argument(
        "--cnn_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the CNN layers. (default: %(default)s)")
    parser.add_argument(
        "--final_output_keep_prob",
        type=float,
        default=1.0,
        help="Keep prob for the final output layer. (default: %(default)s)")
    parser.add_argument(
        "--lstm_out_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the RNN layers. (default: %(default)s)")
    parser.add_argument(
        "--embed_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the embedding layer. (default: %(default)s)")
    parser.add_argument(
        "--attention_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the self-attention layer. (default: %(default)s)")
    parser.add_argument(
        "--child_l2_reg",
        type=float,
        default=3e-6,
        help="Weight decay factor. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_max",
        type=float,
        default=0.002,
        help="The max learning rate. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_min",
        type=float,
        default=0.001,
        help="The min learning rate. (default: %(default)s)")
    parser.add_argument(
        "--child_optim_algo",
        type=str,
        default="adam",
        help="Optimization algorithm; only 'adam' is supported. (default: %(default)s)")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="best_checkpoint",
        help="Path for saved checkpoints. (default: %(default)s)")
    parser.add_argument(
        "--output_type",
        type=str,
        default="avg",
        help="Operator type for the time-step reduction. (default: %(default)s)")
    parser.add_argument(
        "--multi_path",
        type=distutils.util.strtobool,
        default=False,
        help="Search for multiple paths in the architecture. (default: %(default)s)")
    parser.add_argument(
        "--is_binary",
        type=distutils.util.strtobool,
        default=False,
        help="Use binary labels for the SST dataset. (default: %(default)s)")
    parser.add_argument(
        "--is_cuda",
        type=distutils.util.strtobool,
        default=True,
        help="Specify the device type. (default: %(default)s)")
    parser.add_argument(
        "--is_mask",
        type=distutils.util.strtobool,
        default=True,
        help="Apply mask. (default: %(default)s)")
    parser.add_argument(
        "--fixed_seed",
        type=distutils.util.strtobool,
        default=True,
        help="Fix the seed. (default: %(default)s)")
    parser.add_argument(
        "--load_checkpoint",
        type=distutils.util.strtobool,
        default=False,
        help="Whether to load a checkpoint. (default: %(default)s)")
    parser.add_argument(
        "--log_every",
        type=int,
        default=50,
        help="Number of steps between log outputs. (default: %(default)s)")
    parser.add_argument(
        "--eval_every_epochs",
        type=int,
        default=1,
        help="Number of epochs between evaluations. (default: %(default)s)")

    global FLAGS
    FLAGS = parser.parse_args()
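

# Example invocation (illustrative; "checkpoints/architecture.json" is a
# hypothetical path, as the architecture JSON is produced by a prior TextNAS
# search step):
#
#   python retrain.py --child_fixed_arc checkpoints/architecture.json \
#       --data_path data --output_dir output --num_epochs 10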


def set_random_seed(seed):
    logger.info("set random seed for data reading: {}".format(seed))
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed(seed)
        # deterministic cuDNN kernels trade speed for reproducibility
        torch.backends.cudnn.deterministic = True


def get_model(embedding, num_layers):
    logger.info("num layers: {0}".format(num_layers))
    assert FLAGS.child_fixed_arc is not None, "Architecture should be provided."

    child_model = Model(
        embedding=embedding,
        hidden_units=FLAGS.child_out_filters_scale * FLAGS.child_out_filters,
        num_layers=num_layers,
        num_classes=FLAGS.class_num,
        choose_from_k=5 if FLAGS.multi_path else 1,
        lstm_keep_prob=FLAGS.lstm_out_keep_prob,
        cnn_keep_prob=FLAGS.cnn_keep_prob,
        att_keep_prob=FLAGS.attention_keep_prob,
        att_mask=FLAGS.is_mask,
        embed_keep_prob=FLAGS.embed_keep_prob,
        final_output_keep_prob=FLAGS.final_output_keep_prob,
        global_pool=FLAGS.output_type)
    apply_fixed_architecture(child_model, FLAGS.child_fixed_arc)
    return child_model


def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None):
    if eval_set == "test":
        assert test_dataloader is not None
        dataloader = test_dataloader
    elif eval_set == "valid":
        assert valid_dataloader is not None
        dataloader = valid_dataloader
    else:
        raise NotImplementedError("Unknown eval_set '{}'".format(eval_set))

    tot_acc = 0
    tot = 0
    losses = []

    with torch.no_grad():  # save memory
        for batch in dataloader:
            (sent_ids, mask), labels = batch
            sent_ids = sent_ids.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = child_model((sent_ids, mask))  # run
            loss = criterion(logits, labels.long())
            loss = loss.mean()
            preds = logits.argmax(dim=1).long()
            acc = torch.eq(preds, labels.long()).long().sum().item()
            # store a Python float: torch.tensor() below cannot build a tensor
            # from a list of CUDA tensors
            losses.append(loss.item())
            tot_acc += acc
            tot += len(labels)

    losses = torch.tensor(losses)
    loss = losses.mean()
    if tot > 0:
        final_acc = float(tot_acc) / tot
    else:
        final_acc = 0
        logger.info("Error in calculating final_acc")
    return final_acc, loss


def print_user_flags(FLAGS, line_limit=80):
    log_strings = "\n" + "-" * line_limit + "\n"
    for flag_name in sorted(vars(FLAGS)):
        value = "{}".format(getattr(FLAGS, flag_name))
        log_string = flag_name
        log_string += "." * (line_limit - len(flag_name) - len(value))
        log_string += value
        log_strings = log_strings + log_string
        log_strings = log_strings + "\n"
    log_strings += "-" * line_limit
    logger.info(log_strings)


def count_model_params(trainable_params):
    num_vars = 0
    for var in trainable_params:
        num_vars += np.prod([dim for dim in var.size()])
    return num_vars


def update_lr(
        optimizer,
        epoch,
        l2_reg=1e-4,
        lr_warmup_val=None,
        lr_init=0.1,
        lr_decay_scheme="cosine",
        lr_max=0.002,
        lr_min=0.000000001,
        lr_T_0=4,
        lr_T_mul=1,
        sync_replicas=False,
        num_aggregate=None,
        num_replicas=None):
    if lr_decay_scheme == "cosine":
        assert lr_max is not None, "Need lr_max to use lr_cosine"
        assert lr_min is not None, "Need lr_min to use lr_cosine"
        assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
        assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"

        # locate the current cycle and the position within it
        T_i = lr_T_0
        t_epoch = epoch
        last_reset = 0
        while True:
            t_epoch -= T_i
            if t_epoch < 0:
                break
            last_reset += T_i
            T_i *= lr_T_mul

        T_curr = epoch - last_reset

        def _update():
            rate = T_curr / T_i * math.pi
            lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate))
            return lr

        learning_rate = _update()
    else:
        raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme))

    # update lr in optimizer
    for params_group in optimizer.param_groups:
        params_group['lr'] = learning_rate
    return learning_rate
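

# Note: the cosine branch above implements cosine annealing with warm
# restarts (the SGDR schedule of Loshchilov & Hutter): within a cycle of
# length T_i,
#     lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * T_curr / T_i))
# and each restart multiplies the cycle length by lr_T_mul. With this
# script's defaults (child_lr_T_0=10, child_lr_T_mul=2), restarts would fall
# at epochs 10, 30, 70, ..., so the default 10-epoch run stays within a
# single decaying cycle.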


def train(device, data_path, output_dir, num_layers):
    logger.info("Build dataloader")
    train_dataset, valid_dataset, test_dataset, embedding = \
        read_data_sst(data_path,
                      FLAGS.max_input_length,
                      FLAGS.min_count,
                      train_ratio=FLAGS.train_ratio,
                      valid_ratio=FLAGS.valid_ratio,
                      is_binary=FLAGS.is_binary)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True)

    logger.info("Build model")
    child_model = get_model(embedding, num_layers)
    logger.info("Finish build model")

    # for name, var in child_model.named_parameters():
    #     logger.info(name, var.size(), var.requires_grad)  # output all params

    num_vars = count_model_params(child_model.parameters())
    logger.info("Model has {} params".format(num_vars))

    for m in child_model.modules():  # initializer
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)

    criterion = nn.CrossEntropyLoss()

    # get optimizer
    if FLAGS.child_optim_algo == "adam":
        optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg)  # with L2
    else:
        raise ValueError("Unknown optim_algo {}".format(FLAGS.child_optim_algo))

    child_model.to(device)
    criterion.to(device)

    logger.info("Start training")
    start_time = time.time()
    step = 0

    # save path
    model_save_path = os.path.join(FLAGS.output_dir, "model.pth")
    best_model_save_path = os.path.join(FLAGS.output_dir, "best_model.pth")
    best_acc = 0
    start_epoch = 0
    if FLAGS.load_checkpoint:
        if os.path.isfile(model_save_path):
            checkpoint = torch.load(model_save_path, map_location=torch.device('cpu'))
            step = checkpoint['step']
            start_epoch = checkpoint['epoch']
            child_model.load_state_dict(checkpoint['child_model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for epoch in range(start_epoch, FLAGS.num_epochs):
        lr = update_lr(optimizer,
                       epoch,
                       l2_reg=FLAGS.child_l2_reg,
                       lr_warmup_val=None,
                       lr_init=FLAGS.child_lr,
                       lr_decay_scheme=FLAGS.child_lr_decay_scheme,
                       lr_max=FLAGS.child_lr_max,
                       lr_min=FLAGS.child_lr_min,
                       lr_T_0=FLAGS.child_lr_T_0,
                       lr_T_mul=FLAGS.child_lr_T_mul)
        child_model.train()
        for batch in train_dataloader:
            (sent_ids, mask), labels = batch
            sent_ids = sent_ids.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            step += 1
            logits = child_model((sent_ids, mask))  # run
            loss = criterion(logits, labels.long())
            loss = loss.mean()
            preds = logits.argmax(dim=1).long()
            acc = torch.eq(preds, labels.long()).long().sum().item()
            optimizer.zero_grad()
            loss.backward()
            grad_norm = 0
            # materialize the parameters: parameters() returns a generator,
            # which would be exhausted after the first clip_grad_norm_ call
            trainable_params = list(child_model.parameters())
            assert FLAGS.child_grad_bound is not None, "Need grad_bound to clip gradients."
            # compute the gradient norm value with an effectively infinite bound
            grad_norm = nn.utils.clip_grad_norm_(trainable_params, 99999999)
            for param in trainable_params:
                nn.utils.clip_grad_norm_(param, FLAGS.child_grad_bound)  # clip grad
            optimizer.step()
            if step % FLAGS.log_every == 0:
                curr_time = time.time()
                log_string = ""
                log_string += "epoch={:<6d}".format(epoch)
                log_string += " ch_step={:<6d}".format(step)
                log_string += " loss={:<8.6f}".format(loss)
                log_string += " lr={:<8.4f}".format(lr)
                log_string += " |g|={:<8.4f}".format(grad_norm)
                log_string += " tr_acc={:<3d}/{:>3d}".format(acc, logits.size()[0])
                log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60)
                logger.info(log_string)

        epoch += 1  # the checkpoint stores the next epoch to resume from
        save_state = {
            'step': step,
            'epoch': epoch,
            'child_model_state_dict': child_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}
        torch.save(save_state, model_save_path)
        child_model.eval()
        logger.info("Epoch {}: Eval".format(epoch))
        eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader)
        logger.info("ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss))
        if eval_acc > best_acc:
            best_acc = eval_acc
            logger.info("Save best model")
            save_state = {
                'step': step,
                'epoch': epoch,
                'child_model_state_dict': child_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()}
            torch.save(save_state, best_model_save_path)

    return eval_acc


def main():
    parse_args()
    if not os.path.isdir(FLAGS.output_dir):
        logger.info("Path {} does not exist. Creating.".format(FLAGS.output_dir))
        os.makedirs(FLAGS.output_dir)
    elif FLAGS.reset_output_dir:
        logger.info("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
        shutil.rmtree(FLAGS.output_dir, ignore_errors=True)
        os.makedirs(FLAGS.output_dir)
    print_user_flags(FLAGS)
    if FLAGS.fixed_seed:
        set_random_seed(FLAGS.global_seed)
    device = torch.device("cuda" if FLAGS.is_cuda else "cpu")
    train(device, FLAGS.data_path, FLAGS.output_dir, FLAGS.child_num_layers)


if __name__ == "__main__":
    main()
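
The checkpoints written by train() are ordinary torch.save dictionaries, so they
can be reloaded outside this script. A minimal sketch, assuming the model has
been rebuilt exactly as in get_model() and that the default output directory was
used:

    import torch

    # Keys match the save_state dict assembled in train().
    checkpoint = torch.load("output/best_model.pth", map_location=torch.device("cpu"))
    child_model.load_state_dict(checkpoint["child_model_state_dict"])  # restore weights
    child_model.eval()  # disable dropout before evaluation
    print("restored step={} epoch={}".format(checkpoint["step"], checkpoint["epoch"]))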
