{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "ML2022Spring - HW7.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "xvSGDbExff_I"
},
"source": [
"# **Homework 7 - BERT (Question Answering)**\n",
"\n",
"If you have any questions, feel free to email us at mlta-2022-spring@googlegroups.com\n",
"\n",
"Slide: [Link](https://docs.google.com/presentation/d/1H5ZONrb2LMOCixLY7D5_5-7LkIaXO6AGEaV2mRdTOMY/edit?usp=sharing) Kaggle: [Link](https://www.kaggle.com/c/ml2022spring-hw7) Data: [Link](https://drive.google.com/uc?id=1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WGOr_eS3wJJf"
},
"source": [
"## Task description\n",
"- Chinese Extractive Question Answering\n",
"  - Input: Paragraph + Question\n",
"  - Output: Answer\n",
"\n",
"- Objective: Learn how to fine-tune a pretrained model on a downstream task using transformers\n",
"\n",
"- Todo\n",
"  - Fine-tune a pretrained Chinese BERT model\n",
"  - Change hyperparameters (e.g. doc_stride)\n",
"  - Apply linear learning rate decay\n",
"  - Try other pretrained models\n",
"  - Improve preprocessing\n",
"  - Improve postprocessing\n",
"- Training tips\n",
"  - Automatic mixed precision\n",
"  - Gradient accumulation (a sketch follows in the next cell)\n",
"  - Ensemble\n",
"\n",
"- Estimated training time (Tesla T4 with automatic mixed precision enabled)\n",
"  - Simple: 8 mins\n",
"  - Medium: 8 mins\n",
"  - Strong: 25 mins\n",
"  - Boss: 2.5 hrs\n"
]
},
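{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the gradient-accumulation tip above (added for illustration, not the official solution; `accum_steps` is a hypothetical hyperparameter, and `model`, `optimizer`, `train_loader`, and `device` are the objects defined later in this notebook):\n",
"\n",
"```python\n",
"accum_steps = 4  # hypothetical: effective batch size = train_batch_size * accum_steps\n",
"optimizer.zero_grad()\n",
"for step, data in enumerate(train_loader, start=1):\n",
"    data = [i.to(device) for i in data]\n",
"    output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2],\n",
"                   start_positions=data[3], end_positions=data[4])\n",
"    (output.loss / accum_steps).backward()  # scale so the accumulated gradient matches one large batch\n",
"    if step % accum_steps == 0:\n",
"        optimizer.step()\n",
"        optimizer.zero_grad()\n",
"```"
]
},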
{
"cell_type": "markdown",
"metadata": {
"id": "TJ1fSAJE2oaC"
},
"source": [
"## Download Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YPrc4Eie9Yo5"
},
"source": [
"# Download link 1\n",
"!gdown --id '1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb' --output hw7_data.zip\n",
"\n",
"# Download link 2 (if the above link fails)\n",
"# !gdown --id '1qwjbRjq481lHsnTrrF4OjKQnxzgoLEFR' --output hw7_data.zip\n",
"\n",
"# Download link 3 (if the above link fails)\n",
"# !gdown --id '1QXuWjNRZH6DscSd6QcRER0cnxmpZvijn' --output hw7_data.zip\n",
"\n",
"!unzip -o hw7_data.zip\n",
"\n",
"# For this HW, K80 < P4 < T4 < P100 <= T4(fp16) < V100\n",
"!nvidia-smi"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TevOvhC03m0h"
},
"source": [
"## Install transformers\n",
"\n",
"Documentation for the toolkit: https://huggingface.co/transformers/"
]
},
{
"cell_type": "code",
"metadata": {
"id": "tbxWFX_jpDom"
},
"source": [
"# You are allowed to change the version of transformers or use other toolkits\n",
"!pip install transformers==4.5.0"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8dKM4yCh4LI_"
},
"source": [
"## Import Packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WOTHHtWJoahe"
},
"source": [
"import json\n",
"import numpy as np\n",
"import random\n",
"import torch\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast\n",
"\n",
"from tqdm.auto import tqdm\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# Fix random seed for reproducibility\n",
"def same_seeds(seed):\n",
"    torch.manual_seed(seed)\n",
"    if torch.cuda.is_available():\n",
"        torch.cuda.manual_seed(seed)\n",
"        torch.cuda.manual_seed_all(seed)\n",
"    np.random.seed(seed)\n",
"    random.seed(seed)\n",
"    torch.backends.cudnn.benchmark = False\n",
"    torch.backends.cudnn.deterministic = True\n",
"same_seeds(0)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7pBtSZP1SKQO"
},
"source": [
"# Change \"fp16_training\" to True to support automatic mixed precision training (fp16)\n",
"fp16_training = False\n",
"\n",
"if fp16_training:\n",
"    !pip install accelerate==0.2.0\n",
"    from accelerate import Accelerator\n",
"    accelerator = Accelerator(fp16=True)\n",
"    device = accelerator.device\n",
"\n",
"# Documentation for the toolkit: https://huggingface.co/docs/accelerate/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2YgXHuVLp_6j"
},
"source": [
"## Load Model and Tokenizer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xyBCYGjAp3ym"
},
"source": [
"model = BertForQuestionAnswering.from_pretrained(\"bert-base-chinese\").to(device)\n",
"tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-chinese\")\n",
"\n",
"# You can safely ignore the warning message (it pops up because new prediction heads for QA are initialized randomly)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "3Td-GTmk5OW4"
},
"source": [
"## Read Data\n",
"\n",
"- Training set: 31690 QA pairs\n",
"- Dev set: 4131 QA pairs\n",
"- Test set: 4957 QA pairs\n",
"\n",
"- {train/dev/test}_questions:\n",
"  - List of dicts with the following keys:\n",
"    - id (int)\n",
"    - paragraph_id (int)\n",
"    - question_text (string)\n",
"    - answer_text (string)\n",
"    - answer_start (int)\n",
"    - answer_end (int)\n",
"- {train/dev/test}_paragraphs:\n",
"  - List of strings\n",
"  - paragraph_ids in questions correspond to indices in paragraphs\n",
"  - A paragraph may be used by several questions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NvX7hlepogvu"
},
"source": [
"def read_data(file):\n",
"    with open(file, 'r', encoding=\"utf-8\") as reader:\n",
"        data = json.load(reader)\n",
"    return data[\"questions\"], data[\"paragraphs\"]\n",
"\n",
"train_questions, train_paragraphs = read_data(\"hw7_train.json\")\n",
"dev_questions, dev_paragraphs = read_data(\"hw7_dev.json\")\n",
"test_questions, test_paragraphs = read_data(\"hw7_test.json\")"
],
"execution_count": null,
"outputs": []
},
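{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added for illustration): print one QA pair and the start of its paragraph to confirm the fields listed above."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Inspect the first training question and the paragraph it points to\n",
"print(train_questions[0])\n",
"print(train_paragraphs[train_questions[0][\"paragraph_id\"]][:50])"
],
"execution_count": null,
"outputs": []
},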
{
"cell_type": "markdown",
"metadata": {
"id": "Fm0rpTHq0e4N"
},
"source": [
"## Tokenize Data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rTZ6B70Hoxie"
},
"source": [
"# Tokenize questions and paragraphs separately\n",
"# 「add_special_tokens」 is set to False since special tokens will be added when tokenized questions and paragraphs are combined in dataset __getitem__\n",
"\n",
"train_questions_tokenized = tokenizer([train_question[\"question_text\"] for train_question in train_questions], add_special_tokens=False)\n",
"dev_questions_tokenized = tokenizer([dev_question[\"question_text\"] for dev_question in dev_questions], add_special_tokens=False)\n",
"test_questions_tokenized = tokenizer([test_question[\"question_text\"] for test_question in test_questions], add_special_tokens=False)\n",
"\n",
"train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)\n",
"dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)\n",
"test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)\n",
"\n",
"# You can safely ignore the warning message as tokenized sequences will be further processed in dataset __getitem__ before being passed to the model"
],
"execution_count": null,
"outputs": []
},
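{
"cell_type": "markdown",
"metadata": {},
"source": [
"The preprocessing below relies on the fast tokenizer's character-to-token alignment, so here is a small demonstration of `char_to_token` (added for illustration):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Map a character position in the raw paragraph to its token index\n",
"sample_paragraph = train_paragraphs[train_questions[0][\"paragraph_id\"]]\n",
"encoding = tokenizer(sample_paragraph, add_special_tokens=False)\n",
"print(encoding.tokens()[:10])     # first few tokens of the paragraph\n",
"print(encoding.char_to_token(0))  # token index containing character 0 (usually 0)"
],
"execution_count": null,
"outputs": []
},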
{
"cell_type": "markdown",
"metadata": {
"id": "Ws8c8_4d5UCI"
},
"source": [
"## Dataset and Dataloader"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Xjooag-Swnuh"
},
"source": [
"class QA_Dataset(Dataset):\n",
"    def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):\n",
"        self.split = split\n",
"        self.questions = questions\n",
"        self.tokenized_questions = tokenized_questions\n",
"        self.tokenized_paragraphs = tokenized_paragraphs\n",
"        self.max_question_len = 40\n",
"        self.max_paragraph_len = 150\n",
"\n",
"        ##### TODO: Change value of doc_stride #####\n",
"        self.doc_stride = 150\n",
  312. "\n",
  313. " # Input sequence length = [CLS] + question + [SEP] + paragraph + [SEP]\n",
  314. " self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1\n",
  315. "\n",
  316. " def __len__(self):\n",
  317. " return len(self.questions)\n",
  318. "\n",
  319. " def __getitem__(self, idx):\n",
  320. " question = self.questions[idx]\n",
  321. " tokenized_question = self.tokenized_questions[idx]\n",
  322. " tokenized_paragraph = self.tokenized_paragraphs[question[\"paragraph_id\"]]\n",
  323. "\n",
  324. " ##### TODO: Preprocessing #####\n",
  325. " # Hint: How to prevent model from learning something it should not learn\n",
  326. "\n",
  327. " if self.split == \"train\":\n",
  328. " # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph \n",
  329. " answer_start_token = tokenized_paragraph.char_to_token(question[\"answer_start\"])\n",
  330. " answer_end_token = tokenized_paragraph.char_to_token(question[\"answer_end\"])\n",
  331. "\n",
  332. " # A single window is obtained by slicing the portion of paragraph containing the answer\n",
  333. " mid = (answer_start_token + answer_end_token) // 2\n",
  334. " paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))\n",
  335. " paragraph_end = paragraph_start + self.max_paragraph_len\n",
  336. " \n",
  337. " # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n",
  338. " input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102] \n",
  339. " input_ids_paragraph = tokenized_paragraph.ids[paragraph_start : paragraph_end] + [102]\t\t\n",
  340. " \n",
  341. " # Convert answer's start/end positions in tokenized_paragraph to start/end positions in the window \n",
  342. " answer_start_token += len(input_ids_question) - paragraph_start\n",
  343. " answer_end_token += len(input_ids_question) - paragraph_start\n",
  344. " \n",
  345. " # Pad sequence and obtain inputs to model \n",
  346. " input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n",
  347. " return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token\n",
  348. "\n",
  349. " # Validation/Testing\n",
  350. " else:\n",
  351. " input_ids_list, token_type_ids_list, attention_mask_list = [], [], []\n",
  352. " \n",
  353. " # Paragraph is split into several windows, each with start positions separated by step \"doc_stride\"\n",
  354. " for i in range(0, len(tokenized_paragraph), self.doc_stride):\n",
  355. " \n",
  356. " # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n",
  357. " input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]\n",
  358. " input_ids_paragraph = tokenized_paragraph.ids[i : i + self.max_paragraph_len] + [102]\n",
  359. " \n",
  360. " # Pad sequence and obtain inputs to model\n",
  361. " input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n",
  362. " \n",
  363. " input_ids_list.append(input_ids)\n",
  364. " token_type_ids_list.append(token_type_ids)\n",
  365. " attention_mask_list.append(attention_mask)\n",
  366. " \n",
  367. " return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list)\n",
  368. "\n",
  369. " def padding(self, input_ids_question, input_ids_paragraph):\n",
  370. " # Pad zeros if sequence length is shorter than max_seq_len\n",
  371. " padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)\n",
  372. " # Indices of input sequence tokens in the vocabulary\n",
  373. " input_ids = input_ids_question + input_ids_paragraph + [0] * padding_len\n",
  374. " # Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]\n",
  375. " token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len\n",
  376. " # Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]\n",
  377. " attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len\n",
  378. " \n",
  379. " return input_ids, token_type_ids, attention_mask\n",
  380. "\n",
  381. "train_set = QA_Dataset(\"train\", train_questions, train_questions_tokenized, train_paragraphs_tokenized)\n",
  382. "dev_set = QA_Dataset(\"dev\", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)\n",
  383. "test_set = QA_Dataset(\"test\", test_questions, test_questions_tokenized, test_paragraphs_tokenized)\n",
  384. "\n",
  385. "train_batch_size = 32\n",
  386. "\n",
  387. "# Note: Do NOT change batch size of dev_loader / test_loader !\n",
  388. "# Although batch size=1, it is actually a batch consisting of several windows from the same QA pair\n",
  389. "train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=True)\n",
  390. "dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=True)\n",
  391. "test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=True)"
  392. ],
  393. "execution_count": null,
  394. "outputs": []
  395. },
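{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check (added for illustration): each training item is one fixed-length window of `max_seq_len = 1 + 40 + 1 + 150 + 1 = 193` tokens, while each dev/test item stacks all windows of its paragraph."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"input_ids, token_type_ids, attention_mask, start, end = train_set[0]\n",
"print(input_ids.shape)      # torch.Size([193])\n",
"print(dev_set[0][0].shape)  # torch.Size([num_windows, 193])"
],
"execution_count": null,
"outputs": []
},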
{
"cell_type": "markdown",
"metadata": {
"id": "5_H1kqhR8CdM"
},
"source": [
"## Function for Evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "SqeA3PLPxOHu"
},
"source": [
"def evaluate(data, output):\n",
"    ##### TODO: Postprocessing #####\n",
"    # There is a bug and room for improvement in postprocessing\n",
"    # Hint: Open your prediction file to see what is wrong\n",
  415. " \n",
  416. " answer = ''\n",
  417. " max_prob = float('-inf')\n",
  418. " num_of_windows = data[0].shape[1]\n",
  419. " \n",
  420. " for k in range(num_of_windows):\n",
  421. " # Obtain answer by choosing the most probable start position / end position\n",
  422. " start_prob, start_index = torch.max(output.start_logits[k], dim=0)\n",
  423. " end_prob, end_index = torch.max(output.end_logits[k], dim=0)\n",
  424. " \n",
  425. " # Probability of answer is calculated as sum of start_prob and end_prob\n",
  426. " prob = start_prob + end_prob\n",
  427. " \n",
  428. " # Replace answer if calculated probability is larger than previous windows\n",
  429. " if prob > max_prob:\n",
  430. " max_prob = prob\n",
  431. " # Convert tokens to chars (e.g. [1920, 7032] --> \"大 金\")\n",
  432. " answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])\n",
  433. " \n",
  434. " # Remove spaces in answer (e.g. \"大 金\" --> \"大金\")\n",
  435. " return answer.replace(' ','')"
  436. ],
  437. "execution_count": null,
  438. "outputs": []
  439. },
  440. {
  441. "cell_type": "markdown",
  442. "metadata": {
  443. "id": "rzHQit6eMnKG"
  444. },
  445. "source": [
  446. "## Training"
  447. ]
  448. },
  449. {
  450. "cell_type": "code",
  451. "metadata": {
  452. "id": "3Q-B6ka7xoCM"
  453. },
  454. "source": [
  455. "num_epoch = 1\n",
  456. "validation = True\n",
  457. "logging_step = 100\n",
  458. "learning_rate = 1e-4\n",
  459. "optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
  460. "\n",
  461. "if fp16_training:\n",
  462. " model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) \n",
  463. "\n",
  464. "model.train()\n",
  465. "\n",
  466. "print(\"Start Training ...\")\n",
  467. "\n",
  468. "for epoch in range(num_epoch):\n",
  469. " step = 1\n",
  470. " train_loss = train_acc = 0\n",
  471. " \n",
  472. " for data in tqdm(train_loader):\t\n",
  473. " # Load all data into GPU\n",
  474. " data = [i.to(device) for i in data]\n",
  475. " \n",
  476. " # Model inputs: input_ids, token_type_ids, attention_mask, start_positions, end_positions (Note: only \"input_ids\" is mandatory)\n",
  477. " # Model outputs: start_logits, end_logits, loss (return when start_positions/end_positions are provided) \n",
  478. " output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])\n",
  479. "\n",
  480. " # Choose the most probable start position / end position\n",
  481. " start_index = torch.argmax(output.start_logits, dim=1)\n",
  482. " end_index = torch.argmax(output.end_logits, dim=1)\n",
  483. " \n",
  484. " # Prediction is correct only if both start_index and end_index are correct\n",
  485. " train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()\n",
  486. " train_loss += output.loss\n",
  487. " \n",
  488. " if fp16_training:\n",
  489. " accelerator.backward(output.loss)\n",
  490. " else:\n",
  491. " output.loss.backward()\n",
  492. " \n",
  493. " optimizer.step()\n",
  494. " optimizer.zero_grad()\n",
  495. " step += 1\n",
  496. "\n",
  497. " ##### TODO: Apply linear learning rate decay #####\n",
  498. " \n",
  499. " \n",
  500. " # Print training loss and accuracy over past logging step\n",
  501. " if step % logging_step == 0:\n",
  502. " print(f\"Epoch {epoch + 1} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc = {train_acc / logging_step:.3f}\")\n",
  503. " train_loss = train_acc = 0\n",
  504. "\n",
  505. " if validation:\n",
  506. " print(\"Evaluating Dev Set ...\")\n",
  507. " model.eval()\n",
  508. " with torch.no_grad():\n",
  509. " dev_acc = 0\n",
  510. " for i, data in enumerate(tqdm(dev_loader)):\n",
  511. " output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n",
  512. " attention_mask=data[2].squeeze(dim=0).to(device))\n",
  513. " # prediction is correct only if answer text exactly matches\n",
  514. " dev_acc += evaluate(data, output) == dev_questions[i][\"answer_text\"]\n",
  515. " print(f\"Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader):.3f}\")\n",
  516. " model.train()\n",
  517. "\n",
  518. "# Save a model and its configuration file to the directory 「saved_model」 \n",
  519. "# i.e. there are two files under the direcory 「saved_model」: 「pytorch_model.bin」 and 「config.json」\n",
  520. "# Saved model can be re-loaded using 「model = BertForQuestionAnswering.from_pretrained(\"saved_model\")」\n",
  521. "print(\"Saving Model ...\")\n",
  522. "model_save_dir = \"saved_model\" \n",
  523. "model.save_pretrained(model_save_dir)"
  524. ],
  525. "execution_count": null,
  526. "outputs": []
  527. },
  528. {
  529. "cell_type": "markdown",
  530. "metadata": {
  531. "id": "kMmdLOKBMsdE"
  532. },
  533. "source": [
  534. "## Testing"
  535. ]
  536. },
  537. {
  538. "cell_type": "code",
  539. "metadata": {
  540. "id": "U5scNKC9xz0C"
  541. },
  542. "source": [
  543. "print(\"Evaluating Test Set ...\")\n",
  544. "\n",
  545. "result = []\n",
  546. "\n",
  547. "model.eval()\n",
  548. "with torch.no_grad():\n",
  549. " for data in tqdm(test_loader):\n",
  550. " output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n",
  551. " attention_mask=data[2].squeeze(dim=0).to(device))\n",
  552. " result.append(evaluate(data, output))\n",
  553. "\n",
  554. "result_file = \"result.csv\"\n",
  555. "with open(result_file, 'w') as f:\t\n",
  556. "\t f.write(\"ID,Answer\\n\")\n",
  557. "\t for i, test_question in enumerate(test_questions):\n",
  558. " # Replace commas in answers with empty strings (since csv is separated by comma)\n",
  559. " # Answers in kaggle are processed in the same way\n",
  560. "\t\t f.write(f\"{test_question['id']},{result[i].replace(',','')}\\n\")\n",
  561. "\n",
  562. "print(f\"Completed! Result is in {result_file}\")"
  563. ],
  564. "execution_count": null,
  565. "outputs": []
  566. }
  567. ]
  568. }