- {
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "name": "ML2022Spring - HW7.ipynb",
- "provenance": [],
- "collapsed_sections": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "xvSGDbExff_I"
- },
- "source": [
- "# **Homework 7 - Bert (Question Answering)**\n",
- "\n",
- "If you have any questions, feel free to email us at mlta-2022-spring@googlegroups.com\n",
- "\n",
- "\n",
- "\n",
- "Slide: [Link](https://docs.google.com/presentation/d/1H5ZONrb2LMOCixLY7D5_5-7LkIaXO6AGEaV2mRdTOMY/edit?usp=sharing) Kaggle: [Link](https://www.kaggle.com/c/ml2022spring-hw7) Data: [Link](https://drive.google.com/uc?id=1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb)\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "WGOr_eS3wJJf"
- },
- "source": [
- "## Task description\n",
- "- Chinese Extractive Question Answering\n",
- " - Input: Paragraph + Question\n",
- " - Output: Answer\n",
- "\n",
- "- Objective: Learn how to fine tune a pretrained model on downstream task using transformers\n",
- "\n",
- "- Todo\n",
- " - Fine tune a pretrained chinese BERT model\n",
- " - Change hyperparameters (e.g. doc_stride)\n",
- " - Apply linear learning rate decay\n",
- " - Try other pretrained models\n",
- " - Improve preprocessing\n",
- " - Improve postprocessing\n",
- "- Training tips\n",
- " - Automatic mixed precision\n",
- " - Gradient accumulation\n",
- " - Ensemble\n",
- "\n",
- "- Estimated training time (tesla t4 with automatic mixed precision enabled)\n",
- " - Simple: 8mins\n",
- " - Medium: 8mins\n",
- " - Strong: 25mins\n",
- " - Boss: 2.5hrs\n",
- " "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TJ1fSAJE2oaC"
- },
- "source": [
- "## Download Dataset"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "YPrc4Eie9Yo5"
- },
- "source": [
- "# Download link 1\n",
- "!gdown --id '1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb' --output hw7_data.zip\n",
- "\n",
- "# Download Link 2 (if the above link fails) \n",
- "# !gdown --id '1qwjbRjq481lHsnTrrF4OjKQnxzgoLEFR' --output hw7_data.zip\n",
- "\n",
- "# Download Link 3 (if the above link fails) \n",
- "# !gdown --id '1QXuWjNRZH6DscSd6QcRER0cnxmpZvijn' --output hw7_data.zip\n",
- "\n",
- "!unzip -o hw7_data.zip\n",
- "\n",
- "# For this HW, K80 < P4 < T4 < P100 <= T4(fp16) < V100\n",
- "!nvidia-smi"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TevOvhC03m0h"
- },
- "source": [
- "## Install transformers\n",
- "\n",
- "Documentation for the toolkit: https://huggingface.co/transformers/"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "tbxWFX_jpDom"
- },
- "source": [
- "# You are allowed to change version of transformers or use other toolkits\n",
- "!pip install transformers==4.5.0"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8dKM4yCh4LI_"
- },
- "source": [
- "## Import Packages"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "WOTHHtWJoahe"
- },
- "source": [
- "import json\n",
- "import numpy as np\n",
- "import random\n",
- "import torch\n",
- "from torch.utils.data import DataLoader, Dataset \n",
- "from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast\n",
- "\n",
- "from tqdm.auto import tqdm\n",
- "\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "\n",
- "# Fix random seed for reproducibility\n",
- "def same_seeds(seed):\n",
- "\t torch.manual_seed(seed)\n",
- "\t if torch.cuda.is_available():\n",
- "\t\t torch.cuda.manual_seed(seed)\n",
- "\t\t torch.cuda.manual_seed_all(seed)\n",
- "\t np.random.seed(seed)\n",
- "\t random.seed(seed)\n",
- "\t torch.backends.cudnn.benchmark = False\n",
- "\t torch.backends.cudnn.deterministic = True\n",
- "same_seeds(0)"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "7pBtSZP1SKQO"
- },
- "source": [
- "# Change \"fp16_training\" to True to support automatic mixed precision training (fp16)\t\n",
- "fp16_training = False\n",
- "\n",
- "if fp16_training:\n",
- " !pip install accelerate==0.2.0\n",
- " from accelerate import Accelerator\n",
- " accelerator = Accelerator(fp16=True)\n",
- " device = accelerator.device\n",
- "\n",
- "# Documentation for the toolkit: https://huggingface.co/docs/accelerate/"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "2YgXHuVLp_6j"
- },
- "source": [
- "## Load Model and Tokenizer\n",
- "\n",
- "\n",
- "\n",
- "\n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "xyBCYGjAp3ym"
- },
- "source": [
- "model = BertForQuestionAnswering.from_pretrained(\"bert-base-chinese\").to(device)\n",
- "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-chinese\")\n",
- "\n",
- "# You can safely ignore the warning message (it pops up because new prediction heads for QA are initialized randomly)"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3Td-GTmk5OW4"
- },
- "source": [
- "## Read Data\n",
- "\n",
- "- Training set: 31690 QA pairs\n",
- "- Dev set: 4131 QA pairs\n",
- "- Test set: 4957 QA pairs\n",
- "\n",
- "- {train/dev/test}_questions:\t\n",
- " - List of dicts with the following keys:\n",
- " - id (int)\n",
- " - paragraph_id (int)\n",
- " - question_text (string)\n",
- " - answer_text (string)\n",
- " - answer_start (int)\n",
- " - answer_end (int)\n",
- "- {train/dev/test}_paragraphs: \n",
- " - List of strings\n",
- " - paragraph_ids in questions correspond to indexs in paragraphs\n",
- " - A paragraph may be used by several questions "
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "NvX7hlepogvu"
- },
- "source": [
- "def read_data(file):\n",
- " with open(file, 'r', encoding=\"utf-8\") as reader:\n",
- " data = json.load(reader)\n",
- " return data[\"questions\"], data[\"paragraphs\"]\n",
- "\n",
- "train_questions, train_paragraphs = read_data(\"hw7_train.json\")\n",
- "dev_questions, dev_paragraphs = read_data(\"hw7_dev.json\")\n",
- "test_questions, test_paragraphs = read_data(\"hw7_test.json\")"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Fm0rpTHq0e4N"
- },
- "source": [
- "## Tokenize Data"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "rTZ6B70Hoxie"
- },
- "source": [
- "# Tokenize questions and paragraphs separately\n",
- "# 「add_special_tokens」 is set to False since special tokens will be added when tokenized questions and paragraphs are combined in datset __getitem__ \n",
- "\n",
- "train_questions_tokenized = tokenizer([train_question[\"question_text\"] for train_question in train_questions], add_special_tokens=False)\n",
- "dev_questions_tokenized = tokenizer([dev_question[\"question_text\"] for dev_question in dev_questions], add_special_tokens=False)\n",
- "test_questions_tokenized = tokenizer([test_question[\"question_text\"] for test_question in test_questions], add_special_tokens=False) \n",
- "\n",
- "train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)\n",
- "dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)\n",
- "test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)\n",
- "\n",
- "# You can safely ignore the warning message as tokenized sequences will be futher processed in datset __getitem__ before passing to model"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Ws8c8_4d5UCI"
- },
- "source": [
- "## Dataset and Dataloader"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Xjooag-Swnuh"
- },
- "source": [
- "class QA_Dataset(Dataset):\n",
- " def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):\n",
- " self.split = split\n",
- " self.questions = questions\n",
- " self.tokenized_questions = tokenized_questions\n",
- " self.tokenized_paragraphs = tokenized_paragraphs\n",
- " self.max_question_len = 40\n",
- " self.max_paragraph_len = 150\n",
- " \n",
- " ##### TODO: Change value of doc_stride #####\n",
- " self.doc_stride = 150\n",
- "\n",
- " # Input sequence length = [CLS] + question + [SEP] + paragraph + [SEP]\n",
- " self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.questions)\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " question = self.questions[idx]\n",
- " tokenized_question = self.tokenized_questions[idx]\n",
- " tokenized_paragraph = self.tokenized_paragraphs[question[\"paragraph_id\"]]\n",
- "\n",
- " ##### TODO: Preprocessing #####\n",
- " # Hint: How to prevent model from learning something it should not learn\n",
- "\n",
- " if self.split == \"train\":\n",
- " # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph \n",
- " answer_start_token = tokenized_paragraph.char_to_token(question[\"answer_start\"])\n",
- " answer_end_token = tokenized_paragraph.char_to_token(question[\"answer_end\"])\n",
- "\n",
- " # A single window is obtained by slicing the portion of paragraph containing the answer\n",
- " mid = (answer_start_token + answer_end_token) // 2\n",
- " paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))\n",
- " paragraph_end = paragraph_start + self.max_paragraph_len\n",
- " \n",
- " # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n",
- " input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102] \n",
- " input_ids_paragraph = tokenized_paragraph.ids[paragraph_start : paragraph_end] + [102]\t\t\n",
- " \n",
- " # Convert answer's start/end positions in tokenized_paragraph to start/end positions in the window \n",
- " answer_start_token += len(input_ids_question) - paragraph_start\n",
- " answer_end_token += len(input_ids_question) - paragraph_start\n",
- " \n",
- " # Pad sequence and obtain inputs to model \n",
- " input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n",
- " return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token\n",
- "\n",
- " # Validation/Testing\n",
- " else:\n",
- " input_ids_list, token_type_ids_list, attention_mask_list = [], [], []\n",
- " \n",
- " # Paragraph is split into several windows, each with start positions separated by step \"doc_stride\"\n",
- " for i in range(0, len(tokenized_paragraph), self.doc_stride):\n",
- " \n",
- " # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n",
- " input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]\n",
- " input_ids_paragraph = tokenized_paragraph.ids[i : i + self.max_paragraph_len] + [102]\n",
- " \n",
- " # Pad sequence and obtain inputs to model\n",
- " input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n",
- " \n",
- " input_ids_list.append(input_ids)\n",
- " token_type_ids_list.append(token_type_ids)\n",
- " attention_mask_list.append(attention_mask)\n",
- " \n",
- " return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list)\n",
- "\n",
- " def padding(self, input_ids_question, input_ids_paragraph):\n",
- " # Pad zeros if sequence length is shorter than max_seq_len\n",
- " padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)\n",
- " # Indices of input sequence tokens in the vocabulary\n",
- " input_ids = input_ids_question + input_ids_paragraph + [0] * padding_len\n",
- " # Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]\n",
- " token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len\n",
- " # Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]\n",
- " attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len\n",
- " \n",
- " return input_ids, token_type_ids, attention_mask\n",
- "\n",
- "train_set = QA_Dataset(\"train\", train_questions, train_questions_tokenized, train_paragraphs_tokenized)\n",
- "dev_set = QA_Dataset(\"dev\", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)\n",
- "test_set = QA_Dataset(\"test\", test_questions, test_questions_tokenized, test_paragraphs_tokenized)\n",
- "\n",
- "train_batch_size = 32\n",
- "\n",
- "# Note: Do NOT change batch size of dev_loader / test_loader !\n",
- "# Although batch size=1, it is actually a batch consisting of several windows from the same QA pair\n",
- "train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=True)\n",
- "dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=True)\n",
- "test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=True)"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5_H1kqhR8CdM"
- },
- "source": [
- "## Function for Evaluation"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "SqeA3PLPxOHu"
- },
- "source": [
- "def evaluate(data, output):\n",
- " ##### TODO: Postprocessing #####\n",
- " # There is a bug and room for improvement in postprocessing \n",
- " # Hint: Open your prediction file to see what is wrong \n",
- " \n",
- " answer = ''\n",
- " max_prob = float('-inf')\n",
- " num_of_windows = data[0].shape[1]\n",
- " \n",
- " for k in range(num_of_windows):\n",
- " # Obtain answer by choosing the most probable start position / end position\n",
- " start_prob, start_index = torch.max(output.start_logits[k], dim=0)\n",
- " end_prob, end_index = torch.max(output.end_logits[k], dim=0)\n",
- " \n",
- " # Probability of answer is calculated as sum of start_prob and end_prob\n",
- " prob = start_prob + end_prob\n",
- " \n",
- " # Replace answer if calculated probability is larger than previous windows\n",
- " if prob > max_prob:\n",
- " max_prob = prob\n",
- " # Convert tokens to chars (e.g. [1920, 7032] --> \"大 金\")\n",
- " answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])\n",
- " \n",
- " # Remove spaces in answer (e.g. \"大 金\" --> \"大金\")\n",
- " return answer.replace(' ','')"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rzHQit6eMnKG"
- },
- "source": [
- "## Training"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "3Q-B6ka7xoCM"
- },
- "source": [
- "num_epoch = 1\n",
- "validation = True\n",
- "logging_step = 100\n",
- "learning_rate = 1e-4\n",
- "optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
- "\n",
- "if fp16_training:\n",
- " model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) \n",
- "\n",
- "model.train()\n",
- "\n",
- "print(\"Start Training ...\")\n",
- "\n",
- "for epoch in range(num_epoch):\n",
- " step = 1\n",
- " train_loss = train_acc = 0\n",
- " \n",
- " for data in tqdm(train_loader):\t\n",
- " # Load all data into GPU\n",
- " data = [i.to(device) for i in data]\n",
- " \n",
- " # Model inputs: input_ids, token_type_ids, attention_mask, start_positions, end_positions (Note: only \"input_ids\" is mandatory)\n",
- " # Model outputs: start_logits, end_logits, loss (return when start_positions/end_positions are provided) \n",
- " output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])\n",
- "\n",
- " # Choose the most probable start position / end position\n",
- " start_index = torch.argmax(output.start_logits, dim=1)\n",
- " end_index = torch.argmax(output.end_logits, dim=1)\n",
- " \n",
- " # Prediction is correct only if both start_index and end_index are correct\n",
- " train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()\n",
- " train_loss += output.loss\n",
- " \n",
- " if fp16_training:\n",
- " accelerator.backward(output.loss)\n",
- " else:\n",
- " output.loss.backward()\n",
- " \n",
- " optimizer.step()\n",
- " optimizer.zero_grad()\n",
- " step += 1\n",
- "\n",
- " ##### TODO: Apply linear learning rate decay #####\n",
- " \n",
- " \n",
- " # Print training loss and accuracy over past logging step\n",
- " if step % logging_step == 0:\n",
- " print(f\"Epoch {epoch + 1} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc = {train_acc / logging_step:.3f}\")\n",
- " train_loss = train_acc = 0\n",
- "\n",
- " if validation:\n",
- " print(\"Evaluating Dev Set ...\")\n",
- " model.eval()\n",
- " with torch.no_grad():\n",
- " dev_acc = 0\n",
- " for i, data in enumerate(tqdm(dev_loader)):\n",
- " output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n",
- " attention_mask=data[2].squeeze(dim=0).to(device))\n",
- " # prediction is correct only if answer text exactly matches\n",
- " dev_acc += evaluate(data, output) == dev_questions[i][\"answer_text\"]\n",
- " print(f\"Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader):.3f}\")\n",
- " model.train()\n",
- "\n",
- "# Save a model and its configuration file to the directory 「saved_model」 \n",
- "# i.e. there are two files under the direcory 「saved_model」: 「pytorch_model.bin」 and 「config.json」\n",
- "# Saved model can be re-loaded using 「model = BertForQuestionAnswering.from_pretrained(\"saved_model\")」\n",
- "print(\"Saving Model ...\")\n",
- "model_save_dir = \"saved_model\" \n",
- "model.save_pretrained(model_save_dir)"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "kMmdLOKBMsdE"
- },
- "source": [
- "## Testing"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "U5scNKC9xz0C"
- },
- "source": [
- "print(\"Evaluating Test Set ...\")\n",
- "\n",
- "result = []\n",
- "\n",
- "model.eval()\n",
- "with torch.no_grad():\n",
- " for data in tqdm(test_loader):\n",
- " output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n",
- " attention_mask=data[2].squeeze(dim=0).to(device))\n",
- " result.append(evaluate(data, output))\n",
- "\n",
- "result_file = \"result.csv\"\n",
- "with open(result_file, 'w') as f:\t\n",
- "\t f.write(\"ID,Answer\\n\")\n",
- "\t for i, test_question in enumerate(test_questions):\n",
- " # Replace commas in answers with empty strings (since csv is separated by comma)\n",
- " # Answers in kaggle are processed in the same way\n",
- "\t\t f.write(f\"{test_question['id']},{result[i].replace(',','')}\\n\")\n",
- "\n",
- "print(f\"Completed! Result is in {result_file}\")"
- ],
- "execution_count": null,
- "outputs": []
- }
- ]
- }