You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ML2022Spring_HW8.ipynb 17 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. {
  2. "nbformat": 4,
  3. "nbformat_minor": 0,
  4. "metadata": {
  5. "accelerator": "GPU",
  6. "colab": {
  7. "name": "ML2022Spring - HW8.ipynb",
  8. "provenance": [],
  9. "collapsed_sections": [
  10. "bDk9r2YOcDc9",
  11. "Oi12tJMYWi0Q",
  12. "DCgNXSsEWuY7",
  13. "HNe7QU7n7cqh",
  14. "6X6fkGPnYyaF",
  15. "1EbfwRREhA7c",
  16. "vrJ9bScg9AgO",
  17. "XKNUImqUhIeq"
  18. ]
  19. },
  20. "kernelspec": {
  21. "display_name": "Python 3",
  22. "name": "python3"
  23. }
  24. },
  25. "cells": [
  26. {
  27. "cell_type": "markdown",
  28. "metadata": {
  29. "id": "YiVfKn-6tXz8"
  30. },
  31. "source": [
  32. "# **Homework 8 - Anomaly Detection**\n",
  33. "\n",
  34. "If there are any questions, please contact mlta-2022spring-ta@googlegroups.com\n",
  35. "\n",
  36. "Slide: [Link]() Kaggle: [Link](https://www.kaggle.com/c/ml2022spring-hw8)"
  37. ]
  38. },
  39. {
  40. "cell_type": "markdown",
  41. "metadata": {
  42. "id": "bDk9r2YOcDc9"
  43. },
  44. "source": [
  45. "# Set up the environment\n"
  46. ]
  47. },
  48. {
  49. "cell_type": "markdown",
  50. "metadata": {
  51. "id": "Oi12tJMYWi0Q"
  52. },
  53. "source": [
  54. "## Package installation"
  55. ]
  56. },
  57. {
  58. "cell_type": "code",
  59. "metadata": {
  60. "id": "7LexxyPWWjJB"
  61. },
  62. "source": [
  63. "# Training progress bar\n",
  64. "!pip install -q qqdm"
  65. ],
  66. "execution_count": null,
  67. "outputs": []
  68. },
  69. {
  70. "cell_type": "markdown",
  71. "metadata": {
  72. "id": "DCgNXSsEWuY7"
  73. },
  74. "source": [
  75. "## Downloading data"
  76. ]
  77. },
  78. {
  79. "cell_type": "code",
  80. "source": [
  81. "!wget https://github.com/MachineLearningHW/HW8_Dataset/releases/download/v1.0.0/data.zip"
  82. ],
  83. "metadata": {
  84. "id": "SCLJtgF2BLSK"
  85. },
  86. "execution_count": null,
  87. "outputs": []
  88. },
  89. {
  90. "cell_type": "code",
  91. "metadata": {
  92. "id": "0K5kmlkuWzhJ"
  93. },
  94. "source": [
  95. "!unzip data.zip"
  96. ],
  97. "execution_count": null,
  98. "outputs": []
  99. },
  100. {
  101. "cell_type": "markdown",
  102. "metadata": {
  103. "id": "HNe7QU7n7cqh"
  104. },
  105. "source": [
  106. "# Import packages"
  107. ]
  108. },
  109. {
  110. "cell_type": "code",
  111. "metadata": {
  112. "id": "Jk3qFK_a7k8P"
  113. },
  114. "source": [
  115. "import random\n",
  116. "import numpy as np\n",
  117. "import torch\n",
  118. "from torch import nn\n",
  119. "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
  120. "import torchvision.transforms as transforms\n",
  121. "import torch.nn.functional as F\n",
  122. "from torch.autograd import Variable\n",
  123. "import torchvision.models as models\n",
  124. "from torch.optim import Adam, AdamW\n",
  125. "from qqdm import qqdm, format_str\n",
  126. "import pandas as pd"
  127. ],
  128. "execution_count": null,
  129. "outputs": []
  130. },
  131. {
  132. "cell_type": "markdown",
  133. "metadata": {
  134. "id": "6X6fkGPnYyaF"
  135. },
  136. "source": [
  137. "# Loading data"
  138. ]
  139. },
  140. {
  141. "cell_type": "code",
  142. "metadata": {
  143. "id": "k7Wd4yiUYzAm"
  144. },
  145. "source": [
  146. "\n",
  147. "train = np.load('data/trainingset.npy', allow_pickle=True)\n",
  148. "test = np.load('data/testingset.npy', allow_pickle=True)\n",
  149. "\n",
  150. "print(train.shape)\n",
  151. "print(test.shape)"
  152. ],
  153. "execution_count": null,
  154. "outputs": []
  155. },
  156. {
  157. "cell_type": "markdown",
  158. "metadata": {
  159. "id": "_flpmj6OYIa6"
  160. },
  161. "source": [
  162. "## Random seed\n",
  163. "Set the random seed to a certain value for reproducibility."
  164. ]
  165. },
  166. {
  167. "cell_type": "code",
  168. "metadata": {
  169. "id": "Gb-dgXQYYI2Q"
  170. },
  171. "source": [
  172. "def same_seeds(seed):\n",
  173. " random.seed(seed)\n",
  174. " np.random.seed(seed)\n",
  175. " torch.manual_seed(seed)\n",
  176. " if torch.cuda.is_available():\n",
  177. " torch.cuda.manual_seed(seed)\n",
  178. " torch.cuda.manual_seed_all(seed)\n",
  179. " torch.backends.cudnn.benchmark = False\n",
  180. " torch.backends.cudnn.deterministic = True\n",
  181. "\n",
  182. "same_seeds(48763)"
  183. ],
  184. "execution_count": null,
  185. "outputs": []
  186. },
  187. {
  188. "cell_type": "markdown",
  189. "metadata": {
  190. "id": "zR9zC0_Df-CR"
  191. },
  192. "source": [
  193. "# Autoencoder"
  194. ]
  195. },
  196. {
  197. "cell_type": "markdown",
  198. "metadata": {
  199. "id": "1EbfwRREhA7c"
  200. },
  201. "source": [
  202. "# Models & loss"
  203. ]
  204. },
  205. {
  206. "cell_type": "code",
  207. "metadata": {
  208. "id": "Wi8ds1fugCkR"
  209. },
  210. "source": [
  211. "class fcn_autoencoder(nn.Module):\n",
  212. " def __init__(self):\n",
  213. " super(fcn_autoencoder, self).__init__()\n",
  214. " self.encoder = nn.Sequential(\n",
  215. " nn.Linear(64 * 64 * 3, 128),\n",
  216. " nn.ReLU(),\n",
  217. " nn.Linear(128, 64),\n",
  218. " nn.ReLU(), \n",
  219. " nn.Linear(64, 12), \n",
  220. " nn.ReLU(), \n",
  221. " nn.Linear(12, 3)\n",
  222. " )\n",
  223. " \n",
  224. " self.decoder = nn.Sequential(\n",
  225. " nn.Linear(3, 12),\n",
  226. " nn.ReLU(), \n",
  227. " nn.Linear(12, 64),\n",
  228. " nn.ReLU(),\n",
  229. " nn.Linear(64, 128),\n",
  230. " nn.ReLU(), \n",
  231. " nn.Linear(128, 64 * 64 * 3), \n",
  232. " nn.Tanh()\n",
  233. " )\n",
  234. "\n",
  235. " def forward(self, x):\n",
  236. " x = self.encoder(x)\n",
  237. " x = self.decoder(x)\n",
  238. " return x\n",
  239. "\n",
  240. "\n",
  241. "class conv_autoencoder(nn.Module):\n",
  242. " def __init__(self):\n",
  243. " super(conv_autoencoder, self).__init__()\n",
  244. " self.encoder = nn.Sequential(\n",
  245. " nn.Conv2d(3, 12, 4, stride=2, padding=1), \n",
  246. " nn.ReLU(),\n",
  247. " nn.Conv2d(12, 24, 4, stride=2, padding=1), \n",
  248. " nn.ReLU(),\n",
  249. "\t\t\t nn.Conv2d(24, 48, 4, stride=2, padding=1), \n",
  250. " nn.ReLU(),\n",
  251. " )\n",
  252. " self.decoder = nn.Sequential(\n",
  253. "\t\t\t nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),\n",
  254. " nn.ReLU(),\n",
  255. "\t\t\t nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1), \n",
  256. " nn.ReLU(),\n",
  257. " nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),\n",
  258. " nn.Tanh(),\n",
  259. " )\n",
  260. "\n",
  261. " def forward(self, x):\n",
  262. " x = self.encoder(x)\n",
  263. " x = self.decoder(x)\n",
  264. " return x\n",
  265. "\n",
  266. "\n",
  267. "class VAE(nn.Module):\n",
  268. " def __init__(self):\n",
  269. " super(VAE, self).__init__()\n",
  270. " self.encoder = nn.Sequential(\n",
  271. " nn.Conv2d(3, 12, 4, stride=2, padding=1), \n",
  272. " nn.ReLU(),\n",
  273. " nn.Conv2d(12, 24, 4, stride=2, padding=1), \n",
  274. " nn.ReLU(),\n",
  275. " )\n",
  276. " self.enc_out_1 = nn.Sequential(\n",
  277. " nn.Conv2d(24, 48, 4, stride=2, padding=1), \n",
  278. " nn.ReLU(),\n",
  279. " )\n",
  280. " self.enc_out_2 = nn.Sequential(\n",
  281. " nn.Conv2d(24, 48, 4, stride=2, padding=1),\n",
  282. " nn.ReLU(),\n",
  283. " )\n",
  284. " self.decoder = nn.Sequential(\n",
  285. "\t\t\t nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1), \n",
  286. " nn.ReLU(),\n",
  287. "\t\t\t nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1), \n",
  288. " nn.ReLU(),\n",
  289. " nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1), \n",
  290. " nn.Tanh(),\n",
  291. " )\n",
  292. "\n",
  293. " def encode(self, x):\n",
  294. " h1 = self.encoder(x)\n",
  295. " return self.enc_out_1(h1), self.enc_out_2(h1)\n",
  296. "\n",
  297. " def reparametrize(self, mu, logvar):\n",
  298. " std = logvar.mul(0.5).exp_()\n",
  299. " if torch.cuda.is_available():\n",
  300. " eps = torch.cuda.FloatTensor(std.size()).normal_()\n",
  301. " else:\n",
  302. " eps = torch.FloatTensor(std.size()).normal_()\n",
  303. " eps = Variable(eps)\n",
  304. " return eps.mul(std).add_(mu)\n",
  305. "\n",
  306. " def decode(self, z):\n",
  307. " return self.decoder(z)\n",
  308. "\n",
  309. " def forward(self, x):\n",
  310. " mu, logvar = self.encode(x)\n",
  311. " z = self.reparametrize(mu, logvar)\n",
  312. " return self.decode(z), mu, logvar\n",
  313. "\n",
  314. "\n",
  315. "def loss_vae(recon_x, x, mu, logvar, criterion):\n",
  316. " \"\"\"\n",
  317. " recon_x: generating images\n",
  318. " x: origin images\n",
  319. " mu: latent mean\n",
  320. " logvar: latent log variance\n",
  321. " \"\"\"\n",
  322. " mse = criterion(recon_x, x)\n",
  323. " KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n",
  324. " KLD = torch.sum(KLD_element).mul_(-0.5)\n",
  325. " return mse + KLD"
  326. ],
  327. "execution_count": null,
  328. "outputs": []
  329. },
  330. {
  331. "cell_type": "markdown",
  332. "metadata": {
  333. "id": "vrJ9bScg9AgO"
  334. },
  335. "source": [
  336. "# Dataset module\n",
  337. "\n",
  338. "Module for obtaining and processing data. The transform function here normalizes image's pixels from [0, 255] to [-1.0, 1.0].\n"
  339. ]
  340. },
  341. {
  342. "cell_type": "code",
  343. "metadata": {
  344. "id": "33fWhE-h9LPq"
  345. },
  346. "source": [
  347. "class CustomTensorDataset(TensorDataset):\n",
  348. " \"\"\"TensorDataset with support of transforms.\n",
  349. " \"\"\"\n",
  350. " def __init__(self, tensors):\n",
  351. " self.tensors = tensors\n",
  352. " if tensors.shape[-1] == 3:\n",
  353. " self.tensors = tensors.permute(0, 3, 1, 2)\n",
  354. " \n",
  355. " self.transform = transforms.Compose([\n",
  356. " transforms.Lambda(lambda x: x.to(torch.float32)),\n",
  357. " transforms.Lambda(lambda x: 2. * x/255. - 1.),\n",
  358. " ])\n",
  359. " \n",
  360. " def __getitem__(self, index):\n",
  361. " x = self.tensors[index]\n",
  362. " \n",
  363. " if self.transform:\n",
  364. " # mapping images to [-1.0, 1.0]\n",
  365. " x = self.transform(x)\n",
  366. "\n",
  367. " return x\n",
  368. "\n",
  369. " def __len__(self):\n",
  370. " return len(self.tensors)"
  371. ],
  372. "execution_count": null,
  373. "outputs": []
  374. },
  375. {
  376. "cell_type": "markdown",
  377. "metadata": {
  378. "id": "XKNUImqUhIeq"
  379. },
  380. "source": [
  381. "# Training"
  382. ]
  383. },
  384. {
  385. "cell_type": "markdown",
  386. "metadata": {
  387. "id": "7ebAJdjFmS08"
  388. },
  389. "source": [
  390. "## Configuration\n"
  391. ]
  392. },
  393. {
  394. "cell_type": "code",
  395. "metadata": {
  396. "id": "in7yLfmqtZTk"
  397. },
  398. "source": [
  399. "# Training hyperparameters\n",
  400. "num_epochs = 50\n",
  401. "batch_size = 2000\n",
  402. "learning_rate = 1e-3\n",
  403. "\n",
  404. "# Build training dataloader\n",
  405. "x = torch.from_numpy(train)\n",
  406. "train_dataset = CustomTensorDataset(x)\n",
  407. "\n",
  408. "train_sampler = RandomSampler(train_dataset)\n",
  409. "train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)\n",
  410. "\n",
  411. "# Model\n",
  412. "model_type = 'vae' # selecting a model type from {'cnn', 'fcn', 'vae', 'resnet'}\n",
  413. "model_classes = {'fcn': fcn_autoencoder(), 'cnn': conv_autoencoder(), 'vae': VAE()}\n",
  414. "model = model_classes[model_type].cuda()\n",
  415. "\n",
  416. "# Loss and optimizer\n",
  417. "criterion = nn.MSELoss()\n",
  418. "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
  419. ],
  420. "execution_count": null,
  421. "outputs": []
  422. },
  423. {
  424. "cell_type": "markdown",
  425. "metadata": {
  426. "id": "wyooN-JPm8sS"
  427. },
  428. "source": [
  429. "## Training loop"
  430. ]
  431. },
  432. {
  433. "cell_type": "code",
  434. "metadata": {
  435. "id": "JoW1UrrxgI_U"
  436. },
  437. "source": [
  438. "\n",
  439. "best_loss = np.inf\n",
  440. "model.train()\n",
  441. "\n",
  442. "qqdm_train = qqdm(range(num_epochs), desc=format_str('bold', 'Description'))\n",
  443. "for epoch in qqdm_train:\n",
  444. " tot_loss = list()\n",
  445. " for data in train_dataloader:\n",
  446. "\n",
  447. " # ===================loading=====================\n",
  448. " img = data.float().cuda()\n",
  449. " if model_type in ['fcn']:\n",
  450. " img = img.view(img.shape[0], -1)\n",
  451. "\n",
  452. " # ===================forward=====================\n",
  453. " output = model(img)\n",
  454. " if model_type in ['vae']:\n",
  455. " loss = loss_vae(output[0], img, output[1], output[2], criterion)\n",
  456. " else:\n",
  457. " loss = criterion(output, img)\n",
  458. "\n",
  459. " tot_loss.append(loss.item())\n",
  460. " # ===================backward====================\n",
  461. " optimizer.zero_grad()\n",
  462. " loss.backward()\n",
  463. " optimizer.step()\n",
  464. " # ===================save_best====================\n",
  465. " mean_loss = np.mean(tot_loss)\n",
  466. " if mean_loss < best_loss:\n",
  467. " best_loss = mean_loss\n",
  468. " torch.save(model, 'best_model_{}.pt'.format(model_type))\n",
  469. " # ===================log========================\n",
  470. " qqdm_train.set_infos({\n",
  471. " 'epoch': f'{epoch + 1:.0f}/{num_epochs:.0f}',\n",
  472. " 'loss': f'{mean_loss:.4f}',\n",
  473. " })\n",
  474. " # ===================save_last========================\n",
  475. " torch.save(model, 'last_model_{}.pt'.format(model_type))"
  476. ],
  477. "execution_count": null,
  478. "outputs": []
  479. },
  480. {
  481. "cell_type": "markdown",
  482. "metadata": {
  483. "id": "Wk0UxFuchLzR"
  484. },
  485. "source": [
  486. "# Inference\n",
  487. "Model is loaded and generates its anomaly score predictions."
  488. ]
  489. },
  490. {
  491. "cell_type": "markdown",
  492. "metadata": {
  493. "id": "evgMW3OwoGqD"
  494. },
  495. "source": [
  496. "## Initialize\n",
  497. "- dataloader\n",
  498. "- model\n",
  499. "- prediction file"
  500. ]
  501. },
  502. {
  503. "cell_type": "code",
  504. "metadata": {
  505. "id": "_MBnXAswoKmq"
  506. },
  507. "source": [
  508. "eval_batch_size = 200\n",
  509. "\n",
  510. "# build testing dataloader\n",
  511. "data = torch.tensor(test, dtype=torch.float32)\n",
  512. "test_dataset = CustomTensorDataset(data)\n",
  513. "test_sampler = SequentialSampler(test_dataset)\n",
  514. "test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=eval_batch_size, num_workers=1)\n",
  515. "eval_loss = nn.MSELoss(reduction='none')\n",
  516. "\n",
  517. "# load trained model\n",
  518. "checkpoint_path = f'last_model_{model_type}.pt'\n",
  519. "model = torch.load(checkpoint_path)\n",
  520. "model.eval()\n",
  521. "\n",
  522. "# prediction file \n",
  523. "out_file = 'prediction.csv'"
  524. ],
  525. "execution_count": null,
  526. "outputs": []
  527. },
  528. {
  529. "cell_type": "code",
  530. "source": [
  531. "anomality = list()\n",
  532. "with torch.no_grad():\n",
  533. " for i, data in enumerate(test_dataloader):\n",
  534. " img = data.float().cuda()\n",
  535. " if model_type in ['fcn']:\n",
  536. " img = img.view(img.shape[0], -1)\n",
  537. " output = model(img)\n",
  538. " if model_type in ['vae']:\n",
  539. " output = output[0]\n",
  540. " if model_type in ['fcn']:\n",
  541. " loss = eval_loss(output, img).sum(-1)\n",
  542. " else:\n",
  543. " loss = eval_loss(output, img).sum([1, 2, 3])\n",
  544. " anomality.append(loss)\n",
  545. "anomality = torch.cat(anomality, axis=0)\n",
  546. "anomality = torch.sqrt(anomality).reshape(len(test), 1).cpu().numpy()\n",
  547. "\n",
  548. "df = pd.DataFrame(anomality, columns=['score'])\n",
  549. "df.to_csv(out_file, index_label = 'ID')"
  550. ],
  551. "metadata": {
  552. "id": "_1IxCX2iCW6V"
  553. },
  554. "execution_count": null,
  555. "outputs": []
  556. }
  557. ]
  558. }