- {
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "HW05.ipynb",
- "provenance": [],
- "collapsed_sections": [
- "nKb4u67-sT_Z",
- "n1rwQysTsdJq",
- "59si_C0Wsms7",
- "oOpG4EBRLwe_",
- "6ZlE_1JnMv56",
- "UDAPmxjRNEEL",
- "ce5n4eS7NQNy",
- "rUB9f1WCNgMH",
- "VFJlkOMONsc6",
- "Gt1lX3DRO_yU",
- "BAGMiun8PnZy",
- "JOVQRHzGQU4-",
- "jegH0bvMQVmR",
- "a65glBVXQZiE",
- "smA0JraEQdxz",
- "Jn4XeawpQjLk"
- ]
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "# Homework Description\n",
- "- English to Chinese (Traditional) Translation\n",
- " - Input: an English sentence (e.g.\t\ttom is a student .)\n",
- " - Output: the Chinese translation (e.g. \t\t湯姆 是 個 學生 。)\n",
- "\n",
- "- TODO\n",
- " - Train a simple RNN seq2seq to acheive translation\n",
- " - Switch to transformer model to boost performance\n",
- " - Apply Back-translation to furthur boost performance"
- ],
- "metadata": {
- "id": "AFEKWoh3p1Mv"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "!nvidia-smi"
- ],
- "metadata": {
- "id": "3Vf1Q79XPQ3D"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Download and import required packages"
- ],
- "metadata": {
- "id": "59neB_Sxp5Ub"
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "rRlFbfFRpZYT"
- },
- "outputs": [],
- "source": [
- "!pip install 'torch>=1.6.0' editdistance matplotlib sacrebleu sacremoses sentencepiece tqdm wandb\n",
- "!pip install --upgrade jupyter ipywidgets"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "!git clone https://github.com/pytorch/fairseq.git\n",
- "!cd fairseq && git checkout 9a1c497\n",
- "!pip install --upgrade ./fairseq/"
- ],
- "metadata": {
- "id": "fSksMTdmp-Wt"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "import sys\n",
- "import pdb\n",
- "import pprint\n",
- "import logging\n",
- "import os\n",
- "import random\n",
- "\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "from torch.utils import data\n",
- "import numpy as np\n",
- "import tqdm.auto as tqdm\n",
- "from pathlib import Path\n",
- "from argparse import Namespace\n",
- "from fairseq import utils\n",
- "\n",
- "import matplotlib.pyplot as plt"
- ],
- "metadata": {
- "id": "uRLTiuIuqGNc"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Fix random seed"
- ],
- "metadata": {
- "id": "0n07Za1XqJzA"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "seed = 73\n",
- "random.seed(seed)\n",
- "torch.manual_seed(seed)\n",
- "if torch.cuda.is_available():\n",
- " torch.cuda.manual_seed(seed)\n",
- " torch.cuda.manual_seed_all(seed) \n",
- "np.random.seed(seed) \n",
- "torch.backends.cudnn.benchmark = False\n",
- "torch.backends.cudnn.deterministic = True"
- ],
- "metadata": {
- "id": "xllxxyWxqI7s"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Dataset\n",
- "\n",
- "## En-Zh Bilingual Parallel Corpus\n",
- "* [TED2020](#reimers-2020-multilingual-sentence-bert)\n",
- " - Raw: 398,066 (sentences) \n",
- " - Processed: 393,980 (sentences)\n",
- " \n",
- "\n",
- "## Testdata\n",
- "- Size: 4,000 (sentences)\n",
- "- **Chinese translation is undisclosed. The provided (.zh) file is psuedo translation, each line is a '。'**"
- ],
- "metadata": {
- "id": "N5ORDJ-2qdYw"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Dataset Download"
- ],
- "metadata": {
- "id": "GQw2mY4Dqkzd"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "data_dir = './DATA/rawdata'\n",
- "dataset_name = 'ted2020'\n",
- "urls = (\n",
- " \"https://github.com/yuhsinchan/ML2022-HW5Dataset/releases/download/v1.0.2/ted2020.tgz\",\n",
- " \"https://github.com/yuhsinchan/ML2022-HW5Dataset/releases/download/v1.0.2/test.tgz\",\n",
- ")\n",
- "file_names = (\n",
- " 'ted2020.tgz', # train & dev\n",
- " 'test.tgz', # test\n",
- ")\n",
- "prefix = Path(data_dir).absolute() / dataset_name\n",
- "\n",
- "prefix.mkdir(parents=True, exist_ok=True)\n",
- "for u, f in zip(urls, file_names):\n",
- " path = prefix/f\n",
- " if not path.exists():\n",
- " !wget {u} -O {path}\n",
- " if path.suffix == \".tgz\":\n",
- " !tar -xvf {path} -C {prefix}\n",
- " elif path.suffix == \".zip\":\n",
- " !unzip -o {path} -d {prefix}\n",
- "!mv {prefix/'raw.en'} {prefix/'train_dev.raw.en'}\n",
- "!mv {prefix/'raw.zh'} {prefix/'train_dev.raw.zh'}\n",
- "!mv {prefix/'test/test.en'} {prefix/'test.raw.en'}\n",
- "!mv {prefix/'test/test.zh'} {prefix/'test.raw.zh'}\n",
- "!rm -rf {prefix/'test'}"
- ],
- "metadata": {
- "id": "SXT42xQtqijD"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Language"
- ],
- "metadata": {
- "id": "YLkJwNiFrIwZ"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "src_lang = 'en'\n",
- "tgt_lang = 'zh'\n",
- "\n",
- "data_prefix = f'{prefix}/train_dev.raw'\n",
- "test_prefix = f'{prefix}/test.raw'"
- ],
- "metadata": {
- "id": "_uJYkCncrKJb"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "!head {data_prefix+'.'+src_lang} -n 5\n",
- "!head {data_prefix+'.'+tgt_lang} -n 5"
- ],
- "metadata": {
- "id": "0t2CPt1brOT3"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Preprocess files"
- ],
- "metadata": {
- "id": "pRoE9UK7r1gY"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import re\n",
- "\n",
- "def strQ2B(ustring):\n",
- " \"\"\"Full width -> half width\"\"\"\n",
- " # reference:https://ithelp.ithome.com.tw/articles/10233122\n",
- " ss = []\n",
- " for s in ustring:\n",
- " rstring = \"\"\n",
- " for uchar in s:\n",
- " inside_code = ord(uchar)\n",
- " if inside_code == 12288: # Full width space: direct conversion\n",
- " inside_code = 32\n",
- " elif (inside_code >= 65281 and inside_code <= 65374): # Full width chars (except space) conversion\n",
- " inside_code -= 65248\n",
- " rstring += chr(inside_code)\n",
- " ss.append(rstring)\n",
- " return ''.join(ss)\n",
- " \n",
- "def clean_s(s, lang):\n",
- " if lang == 'en':\n",
- " s = re.sub(r\"\\([^()]*\\)\", \"\", s) # remove ([text])\n",
- " s = s.replace('-', '') # remove '-'\n",
- " s = re.sub('([.,;!?()\\\"])', r' \\1 ', s) # keep punctuation\n",
- " elif lang == 'zh':\n",
- " s = strQ2B(s) # Q2B\n",
- " s = re.sub(r\"\\([^()]*\\)\", \"\", s) # remove ([text])\n",
- " s = s.replace(' ', '')\n",
- " s = s.replace('—', '')\n",
- " s = s.replace('“', '\"')\n",
- " s = s.replace('”', '\"')\n",
- " s = s.replace('_', '')\n",
- " s = re.sub('([。,;!?()\\\"~「」])', r' \\1 ', s) # keep punctuation\n",
- " s = ' '.join(s.strip().split())\n",
- " return s\n",
- "\n",
- "def len_s(s, lang):\n",
- " if lang == 'zh':\n",
- " return len(s)\n",
- " return len(s.split())\n",
- "\n",
- "def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):\n",
- " if Path(f'{prefix}.clean.{l1}').exists() and Path(f'{prefix}.clean.{l2}').exists():\n",
- " print(f'{prefix}.clean.{l1} & {l2} exists. skipping clean.')\n",
- " return\n",
- " with open(f'{prefix}.{l1}', 'r') as l1_in_f:\n",
- " with open(f'{prefix}.{l2}', 'r') as l2_in_f:\n",
- " with open(f'{prefix}.clean.{l1}', 'w') as l1_out_f:\n",
- " with open(f'{prefix}.clean.{l2}', 'w') as l2_out_f:\n",
- " for s1 in l1_in_f:\n",
- " s1 = s1.strip()\n",
- " s2 = l2_in_f.readline().strip()\n",
- " s1 = clean_s(s1, l1)\n",
- " s2 = clean_s(s2, l2)\n",
- " s1_len = len_s(s1, l1)\n",
- " s2_len = len_s(s2, l2)\n",
- " if min_len > 0: # remove short sentence\n",
- " if s1_len < min_len or s2_len < min_len:\n",
- " continue\n",
- " if max_len > 0: # remove long sentence\n",
- " if s1_len > max_len or s2_len > max_len:\n",
- " continue\n",
- " if ratio > 0: # remove by ratio of length\n",
- " if s1_len/s2_len > ratio or s2_len/s1_len > ratio:\n",
- " continue\n",
- " print(s1, file=l1_out_f)\n",
- " print(s2, file=l2_out_f)"
- ],
- "metadata": {
- "id": "3tzFwtnFrle3"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "clean_corpus(data_prefix, src_lang, tgt_lang)\n",
- "clean_corpus(test_prefix, src_lang, tgt_lang, ratio=-1, min_len=-1, max_len=-1)"
- ],
- "metadata": {
- "id": "h_i8b1PRr9Nf"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "!head {data_prefix+'.clean.'+src_lang} -n 5\n",
- "!head {data_prefix+'.clean.'+tgt_lang} -n 5"
- ],
- "metadata": {
- "id": "gjT3XCy9r_rj"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Split into train/valid"
- ],
- "metadata": {
- "id": "nKb4u67-sT_Z"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "valid_ratio = 0.01 # 3000~4000 would suffice\n",
- "train_ratio = 1 - valid_ratio"
- ],
- "metadata": {
- "id": "AuFKeDz3sGHL"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "if (prefix/f'train.clean.{src_lang}').exists() \\\n",
- "and (prefix/f'train.clean.{tgt_lang}').exists() \\\n",
- "and (prefix/f'valid.clean.{src_lang}').exists() \\\n",
- "and (prefix/f'valid.clean.{tgt_lang}').exists():\n",
- " print(f'train/valid splits exists. skipping split.')\n",
- "else:\n",
- " line_num = sum(1 for line in open(f'{data_prefix}.clean.{src_lang}'))\n",
- " labels = list(range(line_num))\n",
- " random.shuffle(labels)\n",
- " for lang in [src_lang, tgt_lang]:\n",
- " train_f = open(os.path.join(data_dir, dataset_name, f'train.clean.{lang}'), 'w')\n",
- " valid_f = open(os.path.join(data_dir, dataset_name, f'valid.clean.{lang}'), 'w')\n",
- " count = 0\n",
- " for line in open(f'{data_prefix}.clean.{lang}', 'r'):\n",
- " if labels[count]/line_num < train_ratio:\n",
- " train_f.write(line)\n",
- " else:\n",
- " valid_f.write(line)\n",
- " count += 1\n",
- " train_f.close()\n",
- " valid_f.close()"
- ],
- "metadata": {
- "id": "QR2NVldqsXyY"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Subword Units \n",
- "Out of vocabulary (OOV) has been a major problem in machine translation. This can be alleviated by using subword units.\n",
- "- We will use the [sentencepiece](#kudo-richardson-2018-sentencepiece) package\n",
- "- select 'unigram' or 'byte-pair encoding (BPE)' algorithm"
- ],
- "metadata": {
- "id": "n1rwQysTsdJq"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import sentencepiece as spm\n",
- "vocab_size = 8000\n",
- "if (prefix/f'spm{vocab_size}.model').exists():\n",
- " print(f'{prefix}/spm{vocab_size}.model exists. skipping spm_train.')\n",
- "else:\n",
- " spm.SentencePieceTrainer.train(\n",
- " input=','.join([f'{prefix}/train.clean.{src_lang}',\n",
- " f'{prefix}/valid.clean.{src_lang}',\n",
- " f'{prefix}/train.clean.{tgt_lang}',\n",
- " f'{prefix}/valid.clean.{tgt_lang}']),\n",
- " model_prefix=prefix/f'spm{vocab_size}',\n",
- " vocab_size=vocab_size,\n",
- " character_coverage=1,\n",
- " model_type='unigram', # 'bpe' works as well\n",
- " input_sentence_size=1e6,\n",
- " shuffle_input_sentence=True,\n",
- " normalization_rule_name='nmt_nfkc_cf',\n",
- " )"
- ],
- "metadata": {
- "id": "Ecwllsa7sZRA"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "spm_model = spm.SentencePieceProcessor(model_file=str(prefix/f'spm{vocab_size}.model'))\n",
- "in_tag = {\n",
- " 'train': 'train.clean',\n",
- " 'valid': 'valid.clean',\n",
- " 'test': 'test.raw.clean',\n",
- "}\n",
- "for split in ['train', 'valid', 'test']:\n",
- " for lang in [src_lang, tgt_lang]:\n",
- " out_path = prefix/f'{split}.{lang}'\n",
- " if out_path.exists():\n",
- " print(f\"{out_path} exists. skipping spm_encode.\")\n",
- " else:\n",
- " with open(prefix/f'{split}.{lang}', 'w') as out_f:\n",
- " with open(prefix/f'{in_tag[split]}.{lang}', 'r') as in_f:\n",
- " for line in in_f:\n",
- " line = line.strip()\n",
- " tok = spm_model.encode(line, out_type=str)\n",
- " print(' '.join(tok), file=out_f)"
- ],
- "metadata": {
- "id": "lQPRNldqse_V"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "!head {data_dir+'/'+dataset_name+'/train.'+src_lang} -n 5\n",
- "!head {data_dir+'/'+dataset_name+'/train.'+tgt_lang} -n 5"
- ],
- "metadata": {
- "id": "4j6lXHjAsjXa"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Binarize the data with fairseq"
- ],
- "metadata": {
- "id": "59si_C0Wsms7"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "binpath = Path('./DATA/data-bin', dataset_name)\n",
- "if binpath.exists():\n",
- " print(binpath, \"exists, will not overwrite!\")\n",
- "else:\n",
- " !python -m fairseq_cli.preprocess \\\n",
- " --source-lang {src_lang}\\\n",
- " --target-lang {tgt_lang}\\\n",
- " --trainpref {prefix/'train'}\\\n",
- " --validpref {prefix/'valid'}\\\n",
- " --testpref {prefix/'test'}\\\n",
- " --destdir {binpath}\\\n",
- " --joined-dictionary\\\n",
- " --workers 2"
- ],
- "metadata": {
- "id": "w-cHVLSpsknh"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Configuration for experiments"
- ],
- "metadata": {
- "id": "szMuH1SWLPWA"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "config = Namespace(\n",
- " datadir = \"./DATA/data-bin/ted2020\",\n",
- " savedir = \"./checkpoints/rnn\",\n",
- " source_lang = \"en\",\n",
- " target_lang = \"zh\",\n",
- " \n",
- " # cpu threads when fetching & processing data.\n",
- " num_workers=2, \n",
- " # batch size in terms of tokens. gradient accumulation increases the effective batchsize.\n",
- " max_tokens=8192,\n",
- " accum_steps=2,\n",
- " \n",
- " # the lr s calculated from Noam lr scheduler. you can tune the maximum lr by this factor.\n",
- " lr_factor=2.,\n",
- " lr_warmup=4000,\n",
- " \n",
- " # clipping gradient norm helps alleviate gradient exploding\n",
- " clip_norm=1.0,\n",
- " \n",
- " # maximum epochs for training\n",
- " max_epoch=15,\n",
- " start_epoch=1,\n",
- " \n",
- " # beam size for beam search\n",
- " beam=5, \n",
- " # generate sequences of maximum length ax + b, where x is the source length\n",
- " max_len_a=1.2, \n",
- " max_len_b=10, \n",
- " # when decoding, post process sentence by removing sentencepiece symbols and jieba tokenization.\n",
- " post_process = \"sentencepiece\",\n",
- " \n",
- " # checkpoints\n",
- " keep_last_epochs=5,\n",
- " resume=None, # if resume from checkpoint name (under config.savedir)\n",
- " \n",
- " # logging\n",
- " use_wandb=False,\n",
- ")"
- ],
- "metadata": {
- "id": "5Luz3_tVLUxs"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Logging\n",
- "- logging package logs ordinary messages\n",
- "- wandb logs the loss, bleu, etc. in the training process"
- ],
- "metadata": {
- "id": "cjrJFvyQLg86"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "logging.basicConfig(\n",
- " format=\"%(asctime)s | %(levelname)s | %(name)s | %(message)s\",\n",
- " datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
- " level=\"INFO\", # \"DEBUG\" \"WARNING\" \"ERROR\"\n",
- " stream=sys.stdout,\n",
- ")\n",
- "proj = \"hw5.seq2seq\"\n",
- "logger = logging.getLogger(proj)\n",
- "if config.use_wandb:\n",
- " import wandb\n",
- " wandb.init(project=proj, name=Path(config.savedir).stem, config=config)"
- ],
- "metadata": {
- "id": "-ZiMyDWALbDk"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# CUDA Environments"
- ],
- "metadata": {
- "id": "BNoSkK45Lmqc"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "cuda_env = utils.CudaEnvironment()\n",
- "utils.CudaEnvironment.pretty_print_cuda_env_list([cuda_env])\n",
- "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
- ],
- "metadata": {
- "id": "oqrsbmcoLqMl"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Dataloading"
- ],
- "metadata": {
- "id": "TbJuBIHLLt2D"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## We borrow the TranslationTask from fairseq\n",
- "* used to load the binarized data created above\n",
- "* well-implemented data iterator (dataloader)\n",
- "* built-in task.source_dictionary and task.target_dictionary are also handy\n",
- "* well-implemented beach search decoder"
- ],
- "metadata": {
- "id": "oOpG4EBRLwe_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "from fairseq.tasks.translation import TranslationConfig, TranslationTask\n",
- "\n",
- "## setup task\n",
- "task_cfg = TranslationConfig(\n",
- " data=config.datadir,\n",
- " source_lang=config.source_lang,\n",
- " target_lang=config.target_lang,\n",
- " train_subset=\"train\",\n",
- " required_seq_len_multiple=8,\n",
- " dataset_impl=\"mmap\",\n",
- " upsample_primary=1,\n",
- ")\n",
- "task = TranslationTask.setup_task(task_cfg)"
- ],
- "metadata": {
- "id": "3gSEy1uFLvVs"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "logger.info(\"loading data for epoch 1\")\n",
- "task.load_dataset(split=\"train\", epoch=1, combine=True) # combine if you have back-translation data.\n",
- "task.load_dataset(split=\"valid\", epoch=1)"
- ],
- "metadata": {
- "id": "mR7Bhov7L4IU"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "sample = task.dataset(\"valid\")[1]\n",
- "pprint.pprint(sample)\n",
- "pprint.pprint(\n",
- " \"Source: \" + \\\n",
- " task.source_dictionary.string(\n",
- " sample['source'],\n",
- " config.post_process,\n",
- " )\n",
- ")\n",
- "pprint.pprint(\n",
- " \"Target: \" + \\\n",
- " task.target_dictionary.string(\n",
- " sample['target'],\n",
- " config.post_process,\n",
- " )\n",
- ")"
- ],
- "metadata": {
- "id": "P0BCEm_9L6ig"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Dataset iterator"
- ],
- "metadata": {
- "id": "UcfCVa2FMBSE"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "* Controls every batch to contain no more than N tokens, which optimizes GPU memory efficiency\n",
- "* Shuffles the training set for every epoch\n",
- "* Ignore sentences exceeding maximum length\n",
- "* Pad all sentences in a batch to the same length, which enables parallel computing by GPU\n",
- "* Add eos and shift one token\n",
- " - teacher forcing: to train the model to predict the next token based on prefix, we feed the right shifted target sequence as the decoder input.\n",
- " - generally, prepending bos to the target would do the job (as shown below)\n",
- "\n",
- " - in fairseq however, this is done by moving the eos token to the begining. Empirically, this has the same effect. For instance:\n",
- " ```\n",
- " # output target (target) and Decoder input (prev_output_tokens): \n",
- " eos = 2\n",
- " target = 419, 711, 238, 888, 792, 60, 968, 8, 2\n",
- " prev_output_tokens = 2, 419, 711, 238, 888, 792, 60, 968, 8\n",
- " ```\n",
- "\n"
- ],
- "metadata": {
- "id": "yBvc-B_6MKZM"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=1, cached=True):\n",
- " batch_iterator = task.get_batch_iterator(\n",
- " dataset=task.dataset(split),\n",
- " max_tokens=max_tokens,\n",
- " max_sentences=None,\n",
- " max_positions=utils.resolve_max_positions(\n",
- " task.max_positions(),\n",
- " max_tokens,\n",
- " ),\n",
- " ignore_invalid_inputs=True,\n",
- " seed=seed,\n",
- " num_workers=num_workers,\n",
- " epoch=epoch,\n",
- " disable_iterator_cache=not cached,\n",
- " # Set this to False to speed up. However, if set to False, changing max_tokens beyond \n",
- " # first call of this method has no effect. \n",
- " )\n",
- " return batch_iterator\n",
- "\n",
- "demo_epoch_obj = load_data_iterator(task, \"valid\", epoch=1, max_tokens=20, num_workers=1, cached=False)\n",
- "demo_iter = demo_epoch_obj.next_epoch_itr(shuffle=True)\n",
- "sample = next(demo_iter)\n",
- "sample"
- ],
- "metadata": {
- "id": "OWFJFmCnMDXW"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "* each batch is a python dict, with string key and Tensor value. Contents are described below:\n",
- "```python\n",
- "batch = {\n",
- " \"id\": id, # id for each example \n",
- " \"nsentences\": len(samples), # batch size (sentences)\n",
- " \"ntokens\": ntokens, # batch size (tokens)\n",
- " \"net_input\": {\n",
- " \"src_tokens\": src_tokens, # sequence in source language\n",
- " \"src_lengths\": src_lengths, # sequence length of each example before padding\n",
- " \"prev_output_tokens\": prev_output_tokens, # right shifted target, as mentioned above.\n",
- " },\n",
- " \"target\": target, # target sequence\n",
- "}\n",
- "```"
- ],
- "metadata": {
- "id": "p86K-0g7Me4M"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Model Architecture\n",
- "* We again inherit fairseq's encoder, decoder and model, so that in the testing phase we can directly leverage fairseq's beam search decoder."
- ],
- "metadata": {
- "id": "9EyDBE5ZMkFZ"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "from fairseq.models import (\n",
- " FairseqEncoder, \n",
- " FairseqIncrementalDecoder,\n",
- " FairseqEncoderDecoderModel\n",
- ")"
- ],
- "metadata": {
- "id": "Hzh74qLIMfW_"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Encoder"
- ],
- "metadata": {
- "id": "OI46v1z7MotH"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "- The Encoder is a RNN or Transformer Encoder. The following description is for RNN. For every input token, Encoder will generate a output vector and a hidden states vector, and the hidden states vector is passed on to the next step. In other words, the Encoder sequentially reads in the input sequence, and outputs a single vector at each timestep, then finally outputs the final hidden states, or content vector, at the last timestep.\n",
- "- Parameters:\n",
- " - *args*\n",
- " - encoder_embed_dim: the dimension of embeddings, this compresses the one-hot vector into fixed dimensions, which achieves dimension reduction\n",
- " - encoder_ffn_embed_dim is the dimension of hidden states and output vectors\n",
- " - encoder_layers is the number of layers for Encoder RNN\n",
- " - dropout determines the probability of a neuron's activation being set to 0, in order to prevent overfitting. Generally this is applied in training, and removed in testing.\n",
- " - *dictionary*: the dictionary provided by fairseq. it's used to obtain the padding index, and in turn the encoder padding mask. \n",
- " - *embed_tokens*: an instance of token embeddings (nn.Embedding)\n",
- "\n",
- "- Inputs: \n",
- " - *src_tokens*: integer sequence representing english e.g. 1, 28, 29, 205, 2 \n",
- "- Outputs: \n",
- " - *outputs*: the output of RNN at each timestep, can be furthur processed by Attention\n",
- " - *final_hiddens*: the hidden states of each timestep, will be passed to decoder for decoding\n",
- " - *encoder_padding_mask*: this tells the decoder which position to ignore\n"
- ],
- "metadata": {
- "id": "Wn0wSeLLMrbc"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class RNNEncoder(FairseqEncoder):\n",
- " def __init__(self, args, dictionary, embed_tokens):\n",
- " super().__init__(dictionary)\n",
- " self.embed_tokens = embed_tokens\n",
- " \n",
- " self.embed_dim = args.encoder_embed_dim\n",
- " self.hidden_dim = args.encoder_ffn_embed_dim\n",
- " self.num_layers = args.encoder_layers\n",
- " \n",
- " self.dropout_in_module = nn.Dropout(args.dropout)\n",
- " self.rnn = nn.GRU(\n",
- " self.embed_dim, \n",
- " self.hidden_dim, \n",
- " self.num_layers, \n",
- " dropout=args.dropout, \n",
- " batch_first=False, \n",
- " bidirectional=True\n",
- " )\n",
- " self.dropout_out_module = nn.Dropout(args.dropout)\n",
- " \n",
- " self.padding_idx = dictionary.pad()\n",
- " \n",
- " def combine_bidir(self, outs, bsz: int):\n",
- " out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()\n",
- " return out.view(self.num_layers, bsz, -1)\n",
- "\n",
- " def forward(self, src_tokens, **unused):\n",
- " bsz, seqlen = src_tokens.size()\n",
- " \n",
- " # get embeddings\n",
- " x = self.embed_tokens(src_tokens)\n",
- " x = self.dropout_in_module(x)\n",
- "\n",
- " # B x T x C -> T x B x C\n",
- " x = x.transpose(0, 1)\n",
- " \n",
- " # pass thru bidirectional RNN\n",
- " h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)\n",
- " x, final_hiddens = self.rnn(x, h0)\n",
- " outputs = self.dropout_out_module(x)\n",
- " # outputs = [sequence len, batch size, hid dim * directions]\n",
- " # hidden = [num_layers * directions, batch size , hid dim]\n",
- " \n",
- " # Since Encoder is bidirectional, we need to concatenate the hidden states of two directions\n",
- " final_hiddens = self.combine_bidir(final_hiddens, bsz)\n",
- " # hidden = [num_layers x batch x num_directions*hidden]\n",
- " \n",
- " encoder_padding_mask = src_tokens.eq(self.padding_idx).t()\n",
- " return tuple(\n",
- " (\n",
- " outputs, # seq_len x batch x hidden\n",
- " final_hiddens, # num_layers x batch x num_directions*hidden\n",
- " encoder_padding_mask, # seq_len x batch\n",
- " )\n",
- " )\n",
- " \n",
- " def reorder_encoder_out(self, encoder_out, new_order):\n",
- " # This is used by fairseq's beam search. How and why is not particularly important here.\n",
- " return tuple(\n",
- " (\n",
- " encoder_out[0].index_select(1, new_order),\n",
- " encoder_out[1].index_select(1, new_order),\n",
- " encoder_out[2].index_select(1, new_order),\n",
- " )\n",
- " )"
- ],
- "metadata": {
- "id": "WcX3W4iGMq-S"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Attention"
- ],
- "metadata": {
- "id": "6ZlE_1JnMv56"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "- When the input sequence is long, \"content vector\" alone cannot accurately represent the whole sequence, attention mechanism can provide the Decoder more information.\n",
- "- According to the **Decoder embeddings** of the current timestep, match the **Encoder outputs** with decoder embeddings to determine correlation, and then sum the Encoder outputs weighted by the correlation as the input to **Decoder** RNN.\n",
- "- Common attention implementations use neural network / dot product as the correlation between **query** (decoder embeddings) and **key** (Encoder outputs), followed by **softmax** to obtain a distribution, and finally **values** (Encoder outputs) is **weighted sum**-ed by said distribution.\n",
- "\n",
- "- Parameters:\n",
- " - *input_embed_dim*: dimensionality of key, should be that of the vector in decoder to attend others\n",
- " - *source_embed_dim*: dimensionality of query, should be that of the vector to be attended to (encoder outputs)\n",
- " - *output_embed_dim*: dimensionality of value, should be that of the vector after attention, expected by the next layer\n",
- "\n",
- "- Inputs: \n",
- " - *inputs*: is the key, the vector to attend to others\n",
- " - *encoder_outputs*: is the query/value, the vector to be attended to\n",
- " - *encoder_padding_mask*: this tells the decoder which position to ignore\n",
- "- Outputs: \n",
- " - *output*: the context vector after attention\n",
- " - *attention score*: the attention distribution\n"
- ],
- "metadata": {
- "id": "ZSFSKt_ZMzgh"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class AttentionLayer(nn.Module):\n",
- " def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):\n",
- " super().__init__()\n",
- "\n",
- " self.input_proj = nn.Linear(input_embed_dim, source_embed_dim, bias=bias)\n",
- " self.output_proj = nn.Linear(\n",
- " input_embed_dim + source_embed_dim, output_embed_dim, bias=bias\n",
- " )\n",
- "\n",
- " def forward(self, inputs, encoder_outputs, encoder_padding_mask):\n",
- " # inputs: T, B, dim\n",
- " # encoder_outputs: S x B x dim\n",
- " # padding mask: S x B\n",
- " \n",
- " # convert all to batch first\n",
- " inputs = inputs.transpose(1,0) # B, T, dim\n",
- " encoder_outputs = encoder_outputs.transpose(1,0) # B, S, dim\n",
- " encoder_padding_mask = encoder_padding_mask.transpose(1,0) # B, S\n",
- " \n",
- " # project to the dimensionality of encoder_outputs\n",
- " x = self.input_proj(inputs)\n",
- "\n",
- " # compute attention\n",
- " # (B, T, dim) x (B, dim, S) = (B, T, S)\n",
- " attn_scores = torch.bmm(x, encoder_outputs.transpose(1,2))\n",
- "\n",
- " # cancel the attention at positions corresponding to padding\n",
- " if encoder_padding_mask is not None:\n",
- " # leveraging broadcast B, S -> (B, 1, S)\n",
- " encoder_padding_mask = encoder_padding_mask.unsqueeze(1)\n",
- " attn_scores = (\n",
- " attn_scores.float()\n",
- " .masked_fill_(encoder_padding_mask, float(\"-inf\"))\n",
- " .type_as(attn_scores)\n",
- " ) # FP16 support: cast to float and back\n",
- "\n",
- " # softmax on the dimension corresponding to source sequence\n",
- " attn_scores = F.softmax(attn_scores, dim=-1)\n",
- "\n",
- " # shape (B, T, S) x (B, S, dim) = (B, T, dim) weighted sum\n",
- " x = torch.bmm(attn_scores, encoder_outputs)\n",
- "\n",
- " # (B, T, dim)\n",
- " x = torch.cat((x, inputs), dim=-1)\n",
- " x = torch.tanh(self.output_proj(x)) # concat + linear + tanh\n",
- " \n",
- " # restore shape (B, T, dim) -> (T, B, dim)\n",
- " return x.transpose(1,0), attn_scores"
- ],
- "metadata": {
- "id": "1Atf_YuCMyyF"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Decoder"
- ],
- "metadata": {
- "id": "doSCOA2gM7fK"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "* The hidden states of **Decoder** will be initialized by the final hidden states of **Encoder** (the content vector)\n",
- "* At the same time, **Decoder** will change its hidden states based on the input of the current timestep (the outputs of previous timesteps), and generates an output\n",
- "* Attention improves the performance\n",
- "* The seq2seq steps are implemented in decoder, so that later the Seq2Seq class can accept RNN and Transformer, without furthur modification.\n",
- "- Parameters:\n",
- " - *args*\n",
- " - decoder_embed_dim: is the dimensionality of the decoder embeddings, similar to encoder_embed_dim,\n",
- " - decoder_ffn_embed_dim: is the dimensionality of the decoder RNN hidden states, similar to encoder_ffn_embed_dim\n",
- " - decoder_layers: number of layers of RNN decoder\n",
- " - share_decoder_input_output_embed: usually, the projection matrix of the decoder will share weights with the decoder input embeddings\n",
- " - *dictionary*: the dictionary provided by fairseq\n",
- " - *embed_tokens*: an instance of token embeddings (nn.Embedding)\n",
- "- Inputs: \n",
- " - *prev_output_tokens*: integer sequence representing the right-shifted target e.g. 1, 28, 29, 205, 2 \n",
- " - *encoder_out*: encoder's output.\n",
- " - *incremental_state*: in order to speed up decoding during test time, we will save the hidden state of each timestep. see forward() for details.\n",
- "- Outputs: \n",
- " - *outputs*: the logits (before softmax) output of decoder for each timesteps\n",
- " - *extra*: unsused"
- ],
- "metadata": {
- "id": "2M8Vod2gNABR"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class RNNDecoder(FairseqIncrementalDecoder):\n",
- " def __init__(self, args, dictionary, embed_tokens):\n",
- " super().__init__(dictionary)\n",
- " self.embed_tokens = embed_tokens\n",
- " \n",
- " assert args.decoder_layers == args.encoder_layers, f\"\"\"seq2seq rnn requires that encoder \n",
- " and decoder have same layers of rnn. got: {args.encoder_layers, args.decoder_layers}\"\"\"\n",
- " assert args.decoder_ffn_embed_dim == args.encoder_ffn_embed_dim*2, f\"\"\"seq2seq-rnn requires \n",
- " that decoder hidden to be 2*encoder hidden dim. got: {args.decoder_ffn_embed_dim, args.encoder_ffn_embed_dim*2}\"\"\"\n",
- " \n",
- " self.embed_dim = args.decoder_embed_dim\n",
- " self.hidden_dim = args.decoder_ffn_embed_dim\n",
- " self.num_layers = args.decoder_layers\n",
- " \n",
- " \n",
- " self.dropout_in_module = nn.Dropout(args.dropout)\n",
- " self.rnn = nn.GRU(\n",
- " self.embed_dim, \n",
- " self.hidden_dim, \n",
- " self.num_layers, \n",
- " dropout=args.dropout, \n",
- " batch_first=False, \n",
- " bidirectional=False\n",
- " )\n",
- " self.attention = AttentionLayer(\n",
- " self.embed_dim, self.hidden_dim, self.embed_dim, bias=False\n",
- " ) \n",
- " # self.attention = None\n",
- " self.dropout_out_module = nn.Dropout(args.dropout)\n",
- " \n",
- " if self.hidden_dim != self.embed_dim:\n",
- " self.project_out_dim = nn.Linear(self.hidden_dim, self.embed_dim)\n",
- " else:\n",
- " self.project_out_dim = None\n",
- " \n",
- " if args.share_decoder_input_output_embed:\n",
- " self.output_projection = nn.Linear(\n",
- " self.embed_tokens.weight.shape[1],\n",
- " self.embed_tokens.weight.shape[0],\n",
- " bias=False,\n",
- " )\n",
- " self.output_projection.weight = self.embed_tokens.weight\n",
- " else:\n",
- " self.output_projection = nn.Linear(\n",
- " self.output_embed_dim, len(dictionary), bias=False\n",
- " )\n",
- " nn.init.normal_(\n",
- " self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5\n",
- " )\n",
- " \n",
- " def forward(self, prev_output_tokens, encoder_out, incremental_state=None, **unused):\n",
- " # extract the outputs from encoder\n",
- " encoder_outputs, encoder_hiddens, encoder_padding_mask = encoder_out\n",
- " # outputs: seq_len x batch x num_directions*hidden\n",
- " # encoder_hiddens: num_layers x batch x num_directions*encoder_hidden\n",
- " # padding_mask: seq_len x batch\n",
- " \n",
- " if incremental_state is not None and len(incremental_state) > 0:\n",
- " # if the information from last timestep is retained, we can continue from there instead of starting from bos\n",
- " prev_output_tokens = prev_output_tokens[:, -1:]\n",
- " cache_state = self.get_incremental_state(incremental_state, \"cached_state\")\n",
- " prev_hiddens = cache_state[\"prev_hiddens\"]\n",
- " else:\n",
- " # incremental state does not exist, either this is training time, or the first timestep of test time\n",
- " # prepare for seq2seq: pass the encoder_hidden to the decoder hidden states\n",
- " prev_hiddens = encoder_hiddens\n",
- " \n",
- " bsz, seqlen = prev_output_tokens.size()\n",
- " \n",
- " # embed tokens\n",
- " x = self.embed_tokens(prev_output_tokens)\n",
- " x = self.dropout_in_module(x)\n",
- "\n",
- " # B x T x C -> T x B x C\n",
- " x = x.transpose(0, 1)\n",
- " \n",
- " # decoder-to-encoder attention\n",
- " if self.attention is not None:\n",
- " x, attn = self.attention(x, encoder_outputs, encoder_padding_mask)\n",
- " \n",
- " # pass thru unidirectional RNN\n",
- " x, final_hiddens = self.rnn(x, prev_hiddens)\n",
- " # outputs = [sequence len, batch size, hid dim]\n",
- " # hidden = [num_layers * directions, batch size , hid dim]\n",
- " x = self.dropout_out_module(x)\n",
- " \n",
- " # project to embedding size (if hidden differs from embed size, and share_embedding is True, \n",
- " # we need to do an extra projection)\n",
- " if self.project_out_dim != None:\n",
- " x = self.project_out_dim(x)\n",
- " \n",
- " # project to vocab size\n",
- " x = self.output_projection(x)\n",
- " \n",
- " # T x B x C -> B x T x C\n",
- " x = x.transpose(1, 0)\n",
- " \n",
- " # if incremental, record the hidden states of current timestep, which will be restored in the next timestep\n",
- " cache_state = {\n",
- " \"prev_hiddens\": final_hiddens,\n",
- " }\n",
- " self.set_incremental_state(incremental_state, \"cached_state\", cache_state)\n",
- " \n",
- " return x, None\n",
- " \n",
- " def reorder_incremental_state(\n",
- " self,\n",
- " incremental_state,\n",
- " new_order,\n",
- " ):\n",
- " # This is used by fairseq's beam search. How and why is not particularly important here.\n",
- " cache_state = self.get_incremental_state(incremental_state, \"cached_state\")\n",
- " prev_hiddens = cache_state[\"prev_hiddens\"]\n",
- " prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]\n",
- " cache_state = {\n",
- " \"prev_hiddens\": torch.stack(prev_hiddens),\n",
- " }\n",
- " self.set_incremental_state(incremental_state, \"cached_state\", cache_state)\n",
- " return"
- ],
- "metadata": {
- "id": "QfvgqHYDM6Lp"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Seq2Seq\n",
- "- Composed of **Encoder** and **Decoder**\n",
- "- Recieves inputs and pass to **Encoder** \n",
- "- Pass the outputs from **Encoder** to **Decoder**\n",
- "- **Decoder** will decode according to outputs of previous timesteps as well as **Encoder** outputs \n",
- "- Once done decoding, return the **Decoder** outputs"
- ],
- "metadata": {
- "id": "UDAPmxjRNEEL"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class Seq2Seq(FairseqEncoderDecoderModel):\n",
- " def __init__(self, args, encoder, decoder):\n",
- " super().__init__(encoder, decoder)\n",
- " self.args = args\n",
- " \n",
- " def forward(\n",
- " self,\n",
- " src_tokens,\n",
- " src_lengths,\n",
- " prev_output_tokens,\n",
- " return_all_hiddens: bool = True,\n",
- " ):\n",
- " \"\"\"\n",
- " Run the forward pass for an encoder-decoder model.\n",
- " \"\"\"\n",
- " encoder_out = self.encoder(\n",
- " src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens\n",
- " )\n",
- " logits, extra = self.decoder(\n",
- " prev_output_tokens,\n",
- " encoder_out=encoder_out,\n",
- " src_lengths=src_lengths,\n",
- " return_all_hiddens=return_all_hiddens,\n",
- " )\n",
- " return logits, extra"
- ],
- "metadata": {
- "id": "oRwKdLa0NEU6"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Model Initialization"
- ],
- "metadata": {
- "id": "zu3C2JfqNHzk"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# # HINT: transformer architecture\n",
- "from fairseq.models.transformer import (\n",
- " TransformerEncoder, \n",
- " TransformerDecoder,\n",
- ")\n",
- "\n",
- "def build_model(args, task):\n",
- " \"\"\" build a model instance based on hyperparameters \"\"\"\n",
- " src_dict, tgt_dict = task.source_dictionary, task.target_dictionary\n",
- "\n",
- " # token embeddings\n",
- " encoder_embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, src_dict.pad())\n",
- " decoder_embed_tokens = nn.Embedding(len(tgt_dict), args.decoder_embed_dim, tgt_dict.pad())\n",
- " \n",
- " # encoder decoder\n",
- " # HINT: TODO: switch to TransformerEncoder & TransformerDecoder\n",
- " encoder = RNNEncoder(args, src_dict, encoder_embed_tokens)\n",
- " decoder = RNNDecoder(args, tgt_dict, decoder_embed_tokens)\n",
- " # encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)\n",
- " # decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)\n",
- "\n",
- " # sequence to sequence model\n",
- " model = Seq2Seq(args, encoder, decoder)\n",
- " \n",
- " # initialization for seq2seq model is important, requires extra handling\n",
- " def init_params(module):\n",
- " from fairseq.modules import MultiheadAttention\n",
- " if isinstance(module, nn.Linear):\n",
- " module.weight.data.normal_(mean=0.0, std=0.02)\n",
- " if module.bias is not None:\n",
- " module.bias.data.zero_()\n",
- " if isinstance(module, nn.Embedding):\n",
- " module.weight.data.normal_(mean=0.0, std=0.02)\n",
- " if module.padding_idx is not None:\n",
- " module.weight.data[module.padding_idx].zero_()\n",
- " if isinstance(module, MultiheadAttention):\n",
- " module.q_proj.weight.data.normal_(mean=0.0, std=0.02)\n",
- " module.k_proj.weight.data.normal_(mean=0.0, std=0.02)\n",
- " module.v_proj.weight.data.normal_(mean=0.0, std=0.02)\n",
- " if isinstance(module, nn.RNNBase):\n",
- " for name, param in module.named_parameters():\n",
- " if \"weight\" in name or \"bias\" in name:\n",
- " param.data.uniform_(-0.1, 0.1)\n",
- " \n",
- " # weight initialization\n",
- " model.apply(init_params)\n",
- " return model"
- ],
- "metadata": {
- "id": "nyI9FOx-NJ2m"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Architecture Related Configuration\n",
- "\n",
- "For strong baseline, please refer to the hyperparameters for *transformer-base* in Table 3 in [Attention is all you need](#vaswani2017)"
- ],
- "metadata": {
- "id": "ce5n4eS7NQNy"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "arch_args = Namespace(\n",
- " encoder_embed_dim=256,\n",
- " encoder_ffn_embed_dim=512,\n",
- " encoder_layers=1,\n",
- " decoder_embed_dim=256,\n",
- " decoder_ffn_embed_dim=1024,\n",
- " decoder_layers=1,\n",
- " share_decoder_input_output_embed=True,\n",
- " dropout=0.3,\n",
- ")\n",
- "\n",
- "# HINT: these patches on parameters for Transformer\n",
- "def add_transformer_args(args):\n",
- " args.encoder_attention_heads=4\n",
- " args.encoder_normalize_before=True\n",
- " \n",
- " args.decoder_attention_heads=4\n",
- " args.decoder_normalize_before=True\n",
- " \n",
- " args.activation_fn=\"relu\"\n",
- " args.max_source_positions=1024\n",
- " args.max_target_positions=1024\n",
- " \n",
- " # patches on default parameters for Transformer (those not set above)\n",
- " from fairseq.models.transformer import base_architecture\n",
- " base_architecture(arch_args)\n",
- "\n",
- "# add_transformer_args(arch_args)"
- ],
- "metadata": {
- "id": "Cyn30VoGNT6N"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "if config.use_wandb:\n",
- " wandb.config.update(vars(arch_args))"
- ],
- "metadata": {
- "id": "Nbb76QLCNZZZ"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "model = build_model(arch_args, task)\n",
- "logger.info(model)"
- ],
- "metadata": {
- "id": "7ZWfxsCDNatH"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Optimization"
- ],
- "metadata": {
- "id": "aHll7GRNNdqc"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Loss: Label Smoothing Regularization\n",
- "* let the model learn to generate less concentrated distribution, and prevent over-confidence\n",
- "* sometimes the ground truth may not be the only answer. thus, when calculating loss, we reserve some probability for incorrect labels\n",
- "* avoids overfitting\n",
- "\n",
- "code [source](https://fairseq.readthedocs.io/en/latest/_modules/fairseq/criterions/label_smoothed_cross_entropy.html)"
- ],
- "metadata": {
- "id": "rUB9f1WCNgMH"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class LabelSmoothedCrossEntropyCriterion(nn.Module):\n",
- " def __init__(self, smoothing, ignore_index=None, reduce=True):\n",
- " super().__init__()\n",
- " self.smoothing = smoothing\n",
- " self.ignore_index = ignore_index\n",
- " self.reduce = reduce\n",
- " \n",
- " def forward(self, lprobs, target):\n",
- " if target.dim() == lprobs.dim() - 1:\n",
- " target = target.unsqueeze(-1)\n",
- " # nll: Negative log likelihood,the cross-entropy when target is one-hot. following line is same as F.nll_loss\n",
- " nll_loss = -lprobs.gather(dim=-1, index=target)\n",
- " # reserve some probability for other labels. thus when calculating cross-entropy, \n",
- " # equivalent to summing the log probs of all labels\n",
- " smooth_loss = -lprobs.sum(dim=-1, keepdim=True)\n",
- " if self.ignore_index is not None:\n",
- " pad_mask = target.eq(self.ignore_index)\n",
- " nll_loss.masked_fill_(pad_mask, 0.0)\n",
- " smooth_loss.masked_fill_(pad_mask, 0.0)\n",
- " else:\n",
- " nll_loss = nll_loss.squeeze(-1)\n",
- " smooth_loss = smooth_loss.squeeze(-1)\n",
- " if self.reduce:\n",
- " nll_loss = nll_loss.sum()\n",
- " smooth_loss = smooth_loss.sum()\n",
- " # when calculating cross-entropy, add the loss of other labels\n",
- " eps_i = self.smoothing / lprobs.size(-1)\n",
- " loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss\n",
- " return loss\n",
- "\n",
- "# generally, 0.1 is good enough\n",
- "criterion = LabelSmoothedCrossEntropyCriterion(\n",
- " smoothing=0.1,\n",
- " ignore_index=task.target_dictionary.pad(),\n",
- ")"
- ],
- "metadata": {
- "id": "IgspdJn0NdYF"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Optimizer: Adam + lr scheduling\n",
- "Inverse square root scheduling is important to the stability when training Transformer. It's later used on RNN as well.\n",
- "Update the learning rate according to the following equation. Linearly increase the first stage, then decay proportionally to the inverse square root of timestep.\n",
- "$$lrate = d_{\\text{model}}^{-0.5}\\cdot\\min({step\\_num}^{-0.5},{step\\_num}\\cdot{warmup\\_steps}^{-1.5})$$"
- ],
- "metadata": {
- "id": "aRalDto2NkJJ"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def get_rate(d_model, step_num, warmup_step):\n",
- " # TODO: Change lr from constant to the equation shown above\n",
- " lr = 0.001\n",
- " return lr"
- ],
- "metadata": {
- "id": "sS7tQj1ROBYm"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "class NoamOpt:\n",
- " \"Optim wrapper that implements rate.\"\n",
- " def __init__(self, model_size, factor, warmup, optimizer):\n",
- " self.optimizer = optimizer\n",
- " self._step = 0\n",
- " self.warmup = warmup\n",
- " self.factor = factor\n",
- " self.model_size = model_size\n",
- " self._rate = 0\n",
- " \n",
- " @property\n",
- " def param_groups(self):\n",
- " return self.optimizer.param_groups\n",
- " \n",
- " def multiply_grads(self, c):\n",
- " \"\"\"Multiplies grads by a constant *c*.\"\"\" \n",
- " for group in self.param_groups:\n",
- " for p in group['params']:\n",
- " if p.grad is not None:\n",
- " p.grad.data.mul_(c)\n",
- " \n",
- " def step(self):\n",
- " \"Update parameters and rate\"\n",
- " self._step += 1\n",
- " rate = self.rate()\n",
- " for p in self.param_groups:\n",
- " p['lr'] = rate\n",
- " self._rate = rate\n",
- " self.optimizer.step()\n",
- " \n",
- " def rate(self, step = None):\n",
- " \"Implement `lrate` above\"\n",
- " if step is None:\n",
- " step = self._step\n",
- " return 0 if not step else self.factor * get_rate(self.model_size, step, self.warmup)"
- ],
- "metadata": {
- "id": "J8hoAjHPNkh3"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Scheduling Visualized"
- ],
- "metadata": {
- "id": "VFJlkOMONsc6"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "optimizer = NoamOpt(\n",
- " model_size=arch_args.encoder_embed_dim, \n",
- " factor=config.lr_factor, \n",
- " warmup=config.lr_warmup, \n",
- " optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))\n",
- "plt.plot(np.arange(1, 100000), [optimizer.rate(i) for i in range(1, 100000)])\n",
- "plt.legend([f\"{optimizer.model_size}:{optimizer.warmup}\"])\n",
- "None"
- ],
- "metadata": {
- "id": "A135fwPCNrQs"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Training Procedure"
- ],
- "metadata": {
- "id": "TOR0g-cVO5ZO"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Training"
- ],
- "metadata": {
- "id": "f-0ZjbK3O8Iv"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "from fairseq.data import iterators\n",
- "from torch.cuda.amp import GradScaler, autocast\n",
- "\n",
- "def train_one_epoch(epoch_itr, model, task, criterion, optimizer, accum_steps=1):\n",
- " itr = epoch_itr.next_epoch_itr(shuffle=True)\n",
- " itr = iterators.GroupedIterator(itr, accum_steps) # gradient accumulation: update every accum_steps samples\n",
- " \n",
- " stats = {\"loss\": []}\n",
- " scaler = GradScaler() # automatic mixed precision (amp) \n",
- " \n",
- " model.train()\n",
- " progress = tqdm.tqdm(itr, desc=f\"train epoch {epoch_itr.epoch}\", leave=False)\n",
- " for samples in progress:\n",
- " model.zero_grad()\n",
- " accum_loss = 0\n",
- " sample_size = 0\n",
- " # gradient accumulation: update every accum_steps samples\n",
- " for i, sample in enumerate(samples):\n",
- " if i == 1:\n",
- " # emptying the CUDA cache after the first step can reduce the chance of OOM\n",
- " torch.cuda.empty_cache()\n",
- "\n",
- " sample = utils.move_to_cuda(sample, device=device)\n",
- " target = sample[\"target\"]\n",
- " sample_size_i = sample[\"ntokens\"]\n",
- " sample_size += sample_size_i\n",
- " \n",
- " # mixed precision training\n",
- " with autocast():\n",
- " net_output = model.forward(**sample[\"net_input\"])\n",
- " lprobs = F.log_softmax(net_output[0], -1) \n",
- " loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1))\n",
- " \n",
- " # logging\n",
- " accum_loss += loss.item()\n",
- " # back-prop\n",
- " scaler.scale(loss).backward() \n",
- " \n",
- " scaler.unscale_(optimizer)\n",
- " optimizer.multiply_grads(1 / (sample_size or 1.0)) # (sample_size or 1.0) handles the case of a zero gradient\n",
- " gnorm = nn.utils.clip_grad_norm_(model.parameters(), config.clip_norm) # grad norm clipping prevents gradient exploding\n",
- " \n",
- " scaler.step(optimizer)\n",
- " scaler.update()\n",
- " \n",
- " # logging\n",
- " loss_print = accum_loss/sample_size\n",
- " stats[\"loss\"].append(loss_print)\n",
- " progress.set_postfix(loss=loss_print)\n",
- " if config.use_wandb:\n",
- " wandb.log({\n",
- " \"train/loss\": loss_print,\n",
- " \"train/grad_norm\": gnorm.item(),\n",
- " \"train/lr\": optimizer.rate(),\n",
- " \"train/sample_size\": sample_size,\n",
- " })\n",
- " \n",
- " loss_print = np.mean(stats[\"loss\"])\n",
- " logger.info(f\"training loss: {loss_print:.4f}\")\n",
- " return stats"
- ],
- "metadata": {
- "id": "foal3xM1O404"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Validation & Inference\n",
- "To prevent overfitting, validation is required every epoch to validate the performance on unseen data.\n",
- "- the procedure is essensially same as training, with the addition of inference step\n",
- "- after validation we can save the model weights\n",
- "\n",
- "Validation loss alone cannot describe the actual performance of the model\n",
- "- Directly produce translation hypotheses based on current model, then calculate BLEU with the reference translation\n",
- "- We can also manually examine the hypotheses' quality\n",
- "- We use fairseq's sequence generator for beam search to generate translation hypotheses"
- ],
- "metadata": {
- "id": "Gt1lX3DRO_yU"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# fairseq's beam search generator\n",
- "# given model and input seqeunce, produce translation hypotheses by beam search\n",
- "sequence_generator = task.build_generator([model], config)\n",
- "\n",
- "def decode(toks, dictionary):\n",
- " # convert from Tensor to human readable sentence\n",
- " s = dictionary.string(\n",
- " toks.int().cpu(),\n",
- " config.post_process,\n",
- " )\n",
- " return s if s else \"<unk>\"\n",
- "\n",
- "def inference_step(sample, model):\n",
- " gen_out = sequence_generator.generate([model], sample)\n",
- " srcs = []\n",
- " hyps = []\n",
- " refs = []\n",
- " for i in range(len(gen_out)):\n",
- " # for each sample, collect the input, hypothesis and reference, later be used to calculate BLEU\n",
- " srcs.append(decode(\n",
- " utils.strip_pad(sample[\"net_input\"][\"src_tokens\"][i], task.source_dictionary.pad()), \n",
- " task.source_dictionary,\n",
- " ))\n",
- " hyps.append(decode(\n",
- " gen_out[i][0][\"tokens\"], # 0 indicates using the top hypothesis in beam\n",
- " task.target_dictionary,\n",
- " ))\n",
- " refs.append(decode(\n",
- " utils.strip_pad(sample[\"target\"][i], task.target_dictionary.pad()), \n",
- " task.target_dictionary,\n",
- " ))\n",
- " return srcs, hyps, refs"
- ],
- "metadata": {
- "id": "2og80HYQPAKq"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "import shutil\n",
- "import sacrebleu\n",
- "\n",
- "def validate(model, task, criterion, log_to_wandb=True):\n",
- " logger.info('begin validation')\n",
- " itr = load_data_iterator(task, \"valid\", 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)\n",
- " \n",
- " stats = {\"loss\":[], \"bleu\": 0, \"srcs\":[], \"hyps\":[], \"refs\":[]}\n",
- " srcs = []\n",
- " hyps = []\n",
- " refs = []\n",
- " \n",
- " model.eval()\n",
- " progress = tqdm.tqdm(itr, desc=f\"validation\", leave=False)\n",
- " with torch.no_grad():\n",
- " for i, sample in enumerate(progress):\n",
- " # validation loss\n",
- " sample = utils.move_to_cuda(sample, device=device)\n",
- " net_output = model.forward(**sample[\"net_input\"])\n",
- "\n",
- " lprobs = F.log_softmax(net_output[0], -1)\n",
- " target = sample[\"target\"]\n",
- " sample_size = sample[\"ntokens\"]\n",
- " loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1)) / sample_size\n",
- " progress.set_postfix(valid_loss=loss.item())\n",
- " stats[\"loss\"].append(loss)\n",
- " \n",
- " # do inference\n",
- " s, h, r = inference_step(sample, model)\n",
- " srcs.extend(s)\n",
- " hyps.extend(h)\n",
- " refs.extend(r)\n",
- " \n",
- " tok = 'zh' if task.cfg.target_lang == 'zh' else '13a'\n",
- " stats[\"loss\"] = torch.stack(stats[\"loss\"]).mean().item()\n",
- " stats[\"bleu\"] = sacrebleu.corpus_bleu(hyps, [refs], tokenize=tok) # 計算BLEU score\n",
- " stats[\"srcs\"] = srcs\n",
- " stats[\"hyps\"] = hyps\n",
- " stats[\"refs\"] = refs\n",
- " \n",
- " if config.use_wandb and log_to_wandb:\n",
- " wandb.log({\n",
- " \"valid/loss\": stats[\"loss\"],\n",
- " \"valid/bleu\": stats[\"bleu\"].score,\n",
- " }, commit=False)\n",
- " \n",
- " showid = np.random.randint(len(hyps))\n",
- " logger.info(\"example source: \" + srcs[showid])\n",
- " logger.info(\"example hypothesis: \" + hyps[showid])\n",
- " logger.info(\"example reference: \" + refs[showid])\n",
- " \n",
- " # show bleu results\n",
- " logger.info(f\"validation loss:\\t{stats['loss']:.4f}\")\n",
- " logger.info(stats[\"bleu\"].format())\n",
- " return stats"
- ],
- "metadata": {
- "id": "y1o7LeDkPDsd"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Save and Load Model Weights\n"
- ],
- "metadata": {
- "id": "1sRF6nd4PGEE"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def validate_and_save(model, task, criterion, optimizer, epoch, save=True): \n",
- " stats = validate(model, task, criterion)\n",
- " bleu = stats['bleu']\n",
- " loss = stats['loss']\n",
- " if save:\n",
- " # save epoch checkpoints\n",
- " savedir = Path(config.savedir).absolute()\n",
- " savedir.mkdir(parents=True, exist_ok=True)\n",
- " \n",
- " check = {\n",
- " \"model\": model.state_dict(),\n",
- " \"stats\": {\"bleu\": bleu.score, \"loss\": loss},\n",
- " \"optim\": {\"step\": optimizer._step}\n",
- " }\n",
- " torch.save(check, savedir/f\"checkpoint{epoch}.pt\")\n",
- " shutil.copy(savedir/f\"checkpoint{epoch}.pt\", savedir/f\"checkpoint_last.pt\")\n",
- " logger.info(f\"saved epoch checkpoint: {savedir}/checkpoint{epoch}.pt\")\n",
- " \n",
- " # save epoch samples\n",
- " with open(savedir/f\"samples{epoch}.{config.source_lang}-{config.target_lang}.txt\", \"w\") as f:\n",
- " for s, h in zip(stats[\"srcs\"], stats[\"hyps\"]):\n",
- " f.write(f\"{s}\\t{h}\\n\")\n",
- "\n",
- " # get best valid bleu \n",
- " if getattr(validate_and_save, \"best_bleu\", 0) < bleu.score:\n",
- " validate_and_save.best_bleu = bleu.score\n",
- " torch.save(check, savedir/f\"checkpoint_best.pt\")\n",
- " \n",
- " del_file = savedir / f\"checkpoint{epoch - config.keep_last_epochs}.pt\"\n",
- " if del_file.exists():\n",
- " del_file.unlink()\n",
- " \n",
- " return stats\n",
- "\n",
- "def try_load_checkpoint(model, optimizer=None, name=None):\n",
- " name = name if name else \"checkpoint_last.pt\"\n",
- " checkpath = Path(config.savedir)/name\n",
- " if checkpath.exists():\n",
- " check = torch.load(checkpath)\n",
- " model.load_state_dict(check[\"model\"])\n",
- " stats = check[\"stats\"]\n",
- " step = \"unknown\"\n",
- " if optimizer != None:\n",
- " optimizer._step = step = check[\"optim\"][\"step\"]\n",
- " logger.info(f\"loaded checkpoint {checkpath}: step={step} loss={stats['loss']} bleu={stats['bleu']}\")\n",
- " else:\n",
- " logger.info(f\"no checkpoints found at {checkpath}!\")"
- ],
- "metadata": {
- "id": "edBuLlkuPGr9"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Main\n",
- "## Training loop"
- ],
- "metadata": {
- "id": "KyIFpibfPJ5u"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "model = model.to(device=device)\n",
- "criterion = criterion.to(device=device)"
- ],
- "metadata": {
- "id": "hu7RZbCUPKQr"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "logger.info(\"task: {}\".format(task.__class__.__name__))\n",
- "logger.info(\"encoder: {}\".format(model.encoder.__class__.__name__))\n",
- "logger.info(\"decoder: {}\".format(model.decoder.__class__.__name__))\n",
- "logger.info(\"criterion: {}\".format(criterion.__class__.__name__))\n",
- "logger.info(\"optimizer: {}\".format(optimizer.__class__.__name__))\n",
- "logger.info(\n",
- " \"num. model params: {:,} (num. trained: {:,})\".format(\n",
- " sum(p.numel() for p in model.parameters()),\n",
- " sum(p.numel() for p in model.parameters() if p.requires_grad),\n",
- " )\n",
- ")\n",
- "logger.info(f\"max tokens per batch = {config.max_tokens}, accumulate steps = {config.accum_steps}\")"
- ],
- "metadata": {
- "id": "5xxlJxU2PeAo"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "epoch_itr = load_data_iterator(task, \"train\", config.start_epoch, config.max_tokens, config.num_workers)\n",
- "try_load_checkpoint(model, optimizer, name=config.resume)\n",
- "while epoch_itr.next_epoch_idx <= config.max_epoch:\n",
- " # train for one epoch\n",
- " train_one_epoch(epoch_itr, model, task, criterion, optimizer, config.accum_steps)\n",
- " stats = validate_and_save(model, task, criterion, optimizer, epoch=epoch_itr.epoch)\n",
- " logger.info(\"end of epoch {}\".format(epoch_itr.epoch)) \n",
- " epoch_itr = load_data_iterator(task, \"train\", epoch_itr.next_epoch_idx, config.max_tokens, config.num_workers)"
- ],
- "metadata": {
- "id": "MSPRqpQUPfaX"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Submission"
- ],
- "metadata": {
- "id": "KyjRwllxPjtf"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# averaging a few checkpoints can have a similar effect to ensemble\n",
- "checkdir=config.savedir\n",
- "!python ./fairseq/scripts/average_checkpoints.py \\\n",
- "--inputs {checkdir} \\\n",
- "--num-epoch-checkpoints 5 \\\n",
- "--output {checkdir}/avg_last_5_checkpoint.pt"
- ],
- "metadata": {
- "id": "N70Gc6smPi1d"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Confirm model weights used to generate submission"
- ],
- "metadata": {
- "id": "BAGMiun8PnZy"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# checkpoint_last.pt : latest epoch\n",
- "# checkpoint_best.pt : highest validation bleu\n",
- "# avg_last_5_checkpoint.pt: the average of last 5 epochs\n",
- "try_load_checkpoint(model, name=\"avg_last_5_checkpoint.pt\")\n",
- "validate(model, task, criterion, log_to_wandb=False)\n",
- "None"
- ],
- "metadata": {
- "id": "tvRdivVUPnsU"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Generate Prediction"
- ],
- "metadata": {
- "id": "ioAIflXpPsxt"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def generate_prediction(model, task, split=\"test\", outfile=\"./prediction.txt\"): \n",
- " task.load_dataset(split=split, epoch=1)\n",
- " itr = load_data_iterator(task, split, 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)\n",
- " \n",
- " idxs = []\n",
- " hyps = []\n",
- "\n",
- " model.eval()\n",
- " progress = tqdm.tqdm(itr, desc=f\"prediction\")\n",
- " with torch.no_grad():\n",
- " for i, sample in enumerate(progress):\n",
- " # validation loss\n",
- " sample = utils.move_to_cuda(sample, device=device)\n",
- "\n",
- " # do inference\n",
- " s, h, r = inference_step(sample, model)\n",
- " \n",
- " hyps.extend(h)\n",
- " idxs.extend(list(sample['id']))\n",
- " \n",
- " # sort based on the order before preprocess\n",
- " hyps = [x for _,x in sorted(zip(idxs,hyps))]\n",
- " \n",
- " with open(outfile, \"w\") as f:\n",
- " for h in hyps:\n",
- " f.write(h+\"\\n\")"
- ],
- "metadata": {
- "id": "oYMxA8FlPtIq"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "generate_prediction(model, task)"
- ],
- "metadata": {
- "id": "Le4RFWXxjmm0"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "raise"
- ],
- "metadata": {
- "id": "wvenyi6BPwnD"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Back-translation"
- ],
- "metadata": {
- "id": "1z0cJE-wPzaU"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Train a backward translation model"
- ],
- "metadata": {
- "id": "5-7uPJ2CP0sm"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "1. Switch the source_lang and target_lang in **config** \n",
- "2. Change the savedir in **config** (eg. \"./checkpoints/transformer-back\")\n",
- "3. Train model"
- ],
- "metadata": {
- "id": "ppGHjg2ZP3sV"
- }
- },
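- {
- "cell_type": "markdown",
- "source": [
- "A minimal sketch of steps 1-2, assuming `config` is the experiment namespace defined earlier; after changing it, re-run the data loading, model, and optimizer cells before training."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "source": [
- "# minimal sketch: flip the translation direction and use a separate savedir\n",
- "config.source_lang, config.target_lang = config.target_lang, config.source_lang\n",
- "config.savedir = './checkpoints/transformer-back'  # example path from the list above"
- ],
- "metadata": {},
- "execution_count": null,
- "outputs": []
- },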
- {
- "cell_type": "markdown",
- "source": [
- "## Generate synthetic data with backward model "
- ],
- "metadata": {
- "id": "waTGz29UP6WI"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Download monolingual data"
- ],
- "metadata": {
- "id": "sIeTsPexP8FL"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "mono_dataset_name = 'mono'"
- ],
- "metadata": {
- "id": "i7N4QlsbP8fh"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "mono_prefix = Path(data_dir).absolute() / mono_dataset_name\n",
- "mono_prefix.mkdir(parents=True, exist_ok=True)\n",
- "\n",
- "urls = (\n",
- " \"https://github.com/yuhsinchan/ML2022-HW5Dataset/releases/download/v1.0.2/ted_zh_corpus.deduped.gz\"\n",
- ")\n",
- "file_names = (\n",
- " 'ted_zh_corpus.deduped.gz',\n",
- ")\n",
- "\n",
- "for u, f in zip(urls, file_names):\n",
- " path = mono_prefix/f\n",
- " if not path.exists():\n",
- " else:\n",
- " !wget {u} -O {path}\n",
- " else:\n",
- " print(f'{f} is exist, skip downloading')\n",
- " if path.suffix == \".tgz\":\n",
- " !tar -xvf {path} -C {prefix}\n",
- " elif path.suffix == \".zip\":\n",
- " !unzip -o {path} -d {prefix}\n",
- " elif path.suffix == \".gz\":\n",
- " !gzip -fkd {path}"
- ],
- "metadata": {
- "id": "396saD9-QBPY"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "### TODO: clean corpus\n",
- "\n",
- "1. remove sentences that are too long or too short\n",
- "2. unify punctuation\n",
- "\n",
- "hint: you can use clean_s() defined above to do this"
- ],
- "metadata": {
- "id": "JOVQRHzGQU4-"
- }
- },
- {
- "cell_type": "code",
- "source": [
- ""
- ],
- "metadata": {
- "id": "eIYmxfUOQSov"
- },
- "execution_count": null,
- "outputs": []
- },
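- {
- "cell_type": "markdown",
- "source": [
- "One possible completion of the clean-corpus TODO, as a sketch: it assumes the `clean_s()` helper from the preprocessing section takes `(sentence, lang)`, and that `gzip -d` above produced `mono_prefix/'ted_zh_corpus.deduped'`. The length bounds are illustrative."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "source": [
- "# hedged sketch: clean the monolingual Chinese corpus line by line\n",
- "in_path = mono_prefix / 'ted_zh_corpus.deduped'\n",
- "out_path = mono_prefix / 'mono.clean.zh'\n",
- "if not out_path.exists():\n",
- "    with open(in_path) as f_in, open(out_path, 'w') as f_out:\n",
- "        for line in f_in:\n",
- "            line = clean_s(line.strip(), 'zh')  # unify punctuation, strip noise\n",
- "            if 1 <= len(line) <= 1000:  # drop empty or overly long sentences\n",
- "                f_out.write(line + '\\n')"
- ],
- "metadata": {},
- "execution_count": null,
- "outputs": []
- },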
- {
- "cell_type": "markdown",
- "source": [
- "### TODO: Subword Units\n",
- "\n",
- "Use the spm model of the backward model to tokenize the data into subword units\n",
- "\n",
- "hint: spm model is located at DATA/raw-data/\\[dataset\\]/spm\\[vocab_num\\].model"
- ],
- "metadata": {
- "id": "jegH0bvMQVmR"
- }
- },
- {
- "cell_type": "code",
- "source": [
- ""
- ],
- "metadata": {
- "id": "vqgR4uUMQZGY"
- },
- "execution_count": null,
- "outputs": []
- },
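- {
- "cell_type": "markdown",
- "source": [
- "One possible completion of the subword TODO, as a sketch. It assumes the spm model path from the hint with vocab size 8000 (adjust to your run) and the cleaned file from the previous sketch. Note that the binarization cell below expects both `mono.tok.zh` and `mono.tok.en`; since the corpus is monolingual, a common placeholder is to write the same tokenized text to both files."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "source": [
- "import sentencepiece as spm\n",
- "\n",
- "# hedged sketch: model path and vocab size are assumptions based on the hint above\n",
- "spm_model = spm.SentencePieceProcessor(model_file='./DATA/raw-data/ted2020/spm8000.model')\n",
- "\n",
- "with open(mono_prefix / 'mono.clean.zh') as f_in, \\\n",
- "     open(mono_prefix / 'mono.tok.zh', 'w') as f_zh, \\\n",
- "     open(mono_prefix / 'mono.tok.en', 'w') as f_en:\n",
- "    for line in f_in:\n",
- "        tok = ' '.join(spm_model.encode(line.strip(), out_type=str))\n",
- "        f_zh.write(tok + '\\n')\n",
- "        f_en.write(tok + '\\n')  # placeholder target so fairseq preprocess can pair the files"
- ],
- "metadata": {},
- "execution_count": null,
- "outputs": []
- },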
- {
- "cell_type": "markdown",
- "source": [
- "### Binarize\n",
- "\n",
- "use fairseq to binarize data"
- ],
- "metadata": {
- "id": "a65glBVXQZiE"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "binpath = Path('./DATA/data-bin', mono_dataset_name)\n",
- "src_dict_file = './DATA/data-bin/ted2020/dict.en.txt'\n",
- "tgt_dict_file = src_dict_file\n",
- "monopref = str(mono_prefix/\"mono.tok\") # whatever filepath you get after applying subword tokenization\n",
- "if binpath.exists():\n",
- " print(binpath, \"exists, will not overwrite!\")\n",
- "else:\n",
- " !python -m fairseq_cli.preprocess\\\n",
- " --source-lang 'zh'\\\n",
- " --target-lang 'en'\\\n",
- " --trainpref {monopref}\\\n",
- " --destdir {binpath}\\\n",
- " --srcdict {src_dict_file}\\\n",
- " --tgtdict {tgt_dict_file}\\\n",
- " --workers 2"
- ],
- "metadata": {
- "id": "b803qA5aQaEu"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "### TODO: Generate synthetic data with backward model\n",
- "\n",
- "Add binarized monolingual data to the original data directory, and name it with \"split_name\"\n",
- "\n",
- "ex. ./DATA/data-bin/ted2020/\\[split_name\\].zh-en.\\[\"en\", \"zh\"\\].\\[\"bin\", \"idx\"\\]\n",
- "\n",
- "then you can use 'generate_prediction(model, task, split=\"split_name\")' to generate translation prediction"
- ],
- "metadata": {
- "id": "smA0JraEQdxz"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# Add binarized monolingual data to the original data directory, and name it with \"split_name\"\n",
- "# ex. ./DATA/data-bin/ted2020/\\[split_name\\].zh-en.\\[\"en\", \"zh\"\\].\\[\"bin\", \"idx\"\\]\n",
- "!cp ./DATA/data-bin/mono/train.zh-en.zh.bin ./DATA/data-bin/ted2020/mono.zh-en.zh.bin\n",
- "!cp ./DATA/data-bin/mono/train.zh-en.zh.idx ./DATA/data-bin/ted2020/mono.zh-en.zh.idx\n",
- "!cp ./DATA/data-bin/mono/train.zh-en.en.bin ./DATA/data-bin/ted2020/mono.zh-en.en.bin\n",
- "!cp ./DATA/data-bin/mono/train.zh-en.en.idx ./DATA/data-bin/ted2020/mono.zh-en.en.idx"
- ],
- "metadata": {
- "id": "jvaOVHeoQfkB"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# hint: do prediction on split='mono' to create prediction_file\n",
- "# generate_prediction( ... ,split=... ,outfile=... )"
- ],
- "metadata": {
- "id": "fFEkxPu-Qhlc"
- },
- "execution_count": null,
- "outputs": []
- },
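- {
- "cell_type": "markdown",
- "source": [
- "Following the hint above, a hedged example call. It assumes `model` currently holds the backward (zh-en) checkpoint; the output path is a hypothetical choice."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "source": [
- "# hedged sketch: translate the 'mono' split with the backward model\n",
- "# './prediction_mono.en' is a hypothetical output path\n",
- "generate_prediction(model, task, split=\"mono\", outfile=\"./prediction_mono.en\")"
- ],
- "metadata": {},
- "execution_count": null,
- "outputs": []
- },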
- {
- "cell_type": "markdown",
- "source": [
- "### TODO: Create new dataset\n",
- "\n",
- "1. Combine the prediction data with monolingual data\n",
- "2. Use the original spm model to tokenize data into Subword Units\n",
- "3. Binarize data with fairseq"
- ],
- "metadata": {
- "id": "Jn4XeawpQjLk"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# Combine prediction_file (.en) and mono.zh (.zh) into a new dataset.\n",
- "# \n",
- "# hint: tokenize prediction_file with the spm model\n",
- "# spm_model.encode(line, out_type=str)\n",
- "# output: ./DATA/rawdata/mono/mono.tok.en & mono.tok.zh\n",
- "#\n",
- "# hint: use fairseq to binarize these two files again\n",
- "# binpath = Path('./DATA/data-bin/synthetic')\n",
- "# src_dict_file = './DATA/data-bin/ted2020/dict.en.txt'\n",
- "# tgt_dict_file = src_dict_file\n",
- "# monopref = ./DATA/rawdata/mono/mono.tok # or whatever path after applying subword tokenization, w/o the suffix (.zh/.en)\n",
- "# if binpath.exists():\n",
- "# print(binpath, \"exists, will not overwrite!\")\n",
- "# else:\n",
- "# !python -m fairseq_cli.preprocess\\\n",
- "# --source-lang 'zh'\\\n",
- "# --target-lang 'en'\\\n",
- "# --trainpref {monopref}\\\n",
- "# --destdir {binpath}\\\n",
- "# --srcdict {src_dict_file}\\\n",
- "# --tgtdict {tgt_dict_file}\\\n",
- "# --workers 2"
- ],
- "metadata": {
- "id": "3R35JTaTQjkm"
- },
- "execution_count": null,
- "outputs": []
- },
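- {
- "cell_type": "markdown",
- "source": [
- "A sketch of steps 1-2 of the TODO, continuing the assumptions from the earlier sketches (`spm_model` loaded, back-translations in the hypothetical `./prediction_mono.en`): re-tokenize the English hypotheses so they pair line by line with `mono.tok.zh`, then run the commented binarization above with `binpath = Path('./DATA/data-bin/synthetic')`."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "source": [
- "# hedged sketch: overwrite the placeholder mono.tok.en with tokenized back-translations,\n",
- "# keeping the line order aligned with mono.tok.zh\n",
- "with open('./prediction_mono.en') as f_in, \\\n",
- "     open(mono_prefix / 'mono.tok.en', 'w') as f_out:\n",
- "    for line in f_in:\n",
- "        f_out.write(' '.join(spm_model.encode(line.strip(), out_type=str)) + '\\n')"
- ],
- "metadata": {},
- "execution_count": null,
- "outputs": []
- },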
- {
- "cell_type": "code",
- "source": [
- "# create a new dataset from all the files prepared above\n",
- "!cp -r ./DATA/data-bin/ted2020/ ./DATA/data-bin/ted2020_with_mono/\n",
- "\n",
- "!cp ./DATA/data-bin/synthetic/train.zh-en.zh.bin ./DATA/data-bin/ted2020_with_mono/train1.en-zh.zh.bin\n",
- "!cp ./DATA/data-bin/synthetic/train.zh-en.zh.idx ./DATA/data-bin/ted2020_with_mono/train1.en-zh.zh.idx\n",
- "!cp ./DATA/data-bin/synthetic/train.zh-en.en.bin ./DATA/data-bin/ted2020_with_mono/train1.en-zh.en.bin\n",
- "!cp ./DATA/data-bin/synthetic/train.zh-en.en.idx ./DATA/data-bin/ted2020_with_mono/train1.en-zh.en.idx"
- ],
- "metadata": {
- "id": "MSkse1tyQnsR"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Created new dataset \"ted2020_with_mono\"\n",
- "\n",
- "1. Change the datadir in **config** (\"./DATA/data-bin/ted2020_with_mono\")\n",
- "2. Switch back the source_lang and target_lang in **config** (\"en\", \"zh\")\n",
- "2. Change the savedir in **config** (eg. \"./checkpoints/transformer-bt\")\n",
- "3. Train model"
- ],
- "metadata": {
- "id": "YVdxVGO3QrSs"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "1. <a name=ott2019fairseq></a>Ott, M., Edunov, S., Baevski, A., Fan, A., Gross, S., Ng, N., ... & Auli, M. (2019, June). fairseq: A Fast, Extensible Toolkit for Sequence Modeling. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics (Demonstrations) (pp. 48-53).\n",
- "2. <a name=vaswani2017></a>Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017, December). Attention is all you need. In Proceedings of the 31st International Conference on Neural Information Processing Systems (pp. 6000-6010).\n",
- "3. <a name=reimers-2020-multilingual-sentence-bert></a>Reimers, N., & Gurevych, I. (2020, November). Making Monolingual Sentence Embeddings Multilingual Using Knowledge Distillation. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP) (pp. 4512-4525).\n",
- "4. <a name=tiedemann2012parallel></a>Tiedemann, J. (2012, May). Parallel Data, Tools and Interfaces in OPUS. In Lrec (Vol. 2012, pp. 2214-2218).\n",
- "5. <a name=kudo-richardson-2018-sentencepiece></a>Kudo, T., & Richardson, J. (2018, November). SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing: System Demonstrations (pp. 66-71).\n",
- "6. <a name=sennrich-etal-2016-improving></a>Sennrich, R., Haddow, B., & Birch, A. (2016, August). Improving Neural Machine Translation Models with Monolingual Data. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) (pp. 86-96).\n",
- "7. <a name=edunov-etal-2018-understanding></a>Edunov, S., Ott, M., Auli, M., & Grangier, D. (2018). Understanding Back-Translation at Scale. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (pp. 489-500).\n",
- "8. https://github.com/ajinkyakulkarni14/TED-Multilingual-Parallel-Corpus\n",
- "9. https://ithelp.ithome.com.tw/articles/10233122\n",
- "10. https://nlp.seas.harvard.edu/2018/04/03/attention.html\n",
- "11. https://colab.research.google.com/github/ga642381/ML2021-Spring/blob/main/HW05/HW05.ipynb"
- ],
- "metadata": {
- "id": "_CZU2beUQtl3"
- }
- }
- ]
- }