diff --git a/11 Quantum ML/作业HW12/HW12_EN.pdf b/11 Quantum ML/作业HW12/HW12_EN.pdf new file mode 100644 index 0000000..f07f555 Binary files /dev/null and b/11 Quantum ML/作业HW12/HW12_EN.pdf differ diff --git a/11 Quantum ML/作业HW12/hw12_reinforcement_learning_chinese_version.ipynb b/11 Quantum ML/作业HW12/hw12_reinforcement_learning_chinese_version.ipynb new file mode 100644 index 0000000..771a9aa --- /dev/null +++ b/11 Quantum ML/作业HW12/hw12_reinforcement_learning_chinese_version.ipynb @@ -0,0 +1,3645 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hw12_reinforcement_learning_chinese_version.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "2acab9542fe64b979fa2ac2adb3f10a8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_f288c64b5ff748eb82178bf1de17934f", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_de34e5b178f5470e98e0275102a65042", + "IPY_MODEL_c93cba301cac439ca56fb6b45bd1c4e4" + ] + } + }, + "f288c64b5ff748eb82178bf1de17934f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": 
null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "de34e5b178f5470e98e0275102a65042": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_43c6ee720b674626ab3a869bda5dd6e3", + "_dom_classes": [], + "description": "Total: -24.0, Final: -40.0: 100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 400, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 400, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_2465d2b109d34922a486341232d86ad6" + } + }, + "c93cba301cac439ca56fb6b45bd1c4e4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_aa27187195be4da9874025395eac35eb", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 400/400 [11:02<00:00, 1.66s/it]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_02d196d4f9734f998455d92bd9300adb" 
+ } + }, + "43c6ee720b674626ab3a869bda5dd6e3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "2465d2b109d34922a486341232d86ad6": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "aa27187195be4da9874025395eac35eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + 
"_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "02d196d4f9734f998455d92bd9300adb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Fp30SB4bxeQb" + }, + "source": [ + "# **Homework 12 - Reinforcement Learning**\n", + "\n", + "若有任何問題,歡迎來信至助教信箱 ntu-ml-2021spring-ta@googlegroups.com\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yXsnCWPtWSNk" + }, + "source": [ + "## 前置作業\n", + "\n", + "首先我們需要安裝必要的系統套件及 PyPi 套件。\n", + "gym 這個套件由 OpenAI 所提供,是一套用來開發與比較 Reinforcement Learning 演算法的工具包(toolkit)。\n", + "而其餘套件則是為了在 Notebook 中繪圖所需要的套件。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5e2bScpnkVbv", + 
"outputId": "dd8cf053-de15-4a11-c146-5f3405d1e377" + }, + "source": [ + "!apt update\n", + "!apt install python-opengl xvfb -y\n", + "!pip install gym[box2d]==0.18.3 pyvirtualdisplay tqdm numpy==1.19.5 torch==1.8.1" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\u001b[33m\r0% [Working]\u001b[0m\r \rGet:1 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]\n", + "\u001b[33m\r0% [Connecting to archive.ubuntu.com (91.189.88.142)] [Waiting for headers] [1 \u001b[0m\u001b[33m\r0% [Connecting to archive.ubuntu.com (91.189.88.142)] [Waiting for headers] [Co\u001b[0m\r \rIgn:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 InRelease\n", + "\u001b[33m\r0% [Connecting to archive.ubuntu.com (91.189.88.142)] [Waiting for headers] [Co\u001b[0m\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rIgn:3 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 InRelease\n", + "\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rGet:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Release [697 B]\n", + "\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rHit:5 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 Release\n", + "\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rGet:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Release.gpg [836 B]\n", + "Get:7 http://security.ubuntu.com/ubuntu bionic-security InRelease [88.7 kB]\n", + "Hit:8 http://archive.ubuntu.com/ubuntu bionic InRelease\n", + "Get:9 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu 
bionic InRelease [15.9 kB]\n", + "Get:10 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ Packages [60.9 kB]\n", + "Get:11 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [88.7 kB]\n", + "Hit:13 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease\n", + "Ign:14 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Packages\n", + "Get:14 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Packages [798 kB]\n", + "Get:15 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [74.6 kB]\n", + "Hit:16 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease\n", + "Get:17 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu bionic InRelease [21.3 kB]\n", + "Get:18 http://security.ubuntu.com/ubuntu bionic-security/restricted amd64 Packages [423 kB]\n", + "Get:19 http://security.ubuntu.com/ubuntu bionic-security/main amd64 Packages [2,152 kB]\n", + "Get:20 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic/main Sources [1,770 kB]\n", + "Get:21 http://security.ubuntu.com/ubuntu bionic-security/universe amd64 Packages [1,413 kB]\n", + "Get:22 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 Packages [2,184 kB]\n", + "Get:23 http://archive.ubuntu.com/ubuntu bionic-updates/restricted amd64 Packages [452 kB]\n", + "Get:24 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 Packages [2,584 kB]\n", + "Get:25 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic/main amd64 Packages [905 kB]\n", + "Get:26 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu bionic/main amd64 Packages [41.5 kB]\n", + "Fetched 13.1 MB in 4s (3,031 kB/s)\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "86 packages can be upgraded. Run 'apt list --upgradable' to see them.\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... 
Done\n", + "The following package was automatically installed and is no longer required:\n", + " libnvidia-common-460\n", + "Use 'apt autoremove' to remove it.\n", + "Suggested packages:\n", + " libgle3\n", + "The following NEW packages will be installed:\n", + " python-opengl xvfb\n", + "0 upgraded, 2 newly installed, 0 to remove and 86 not upgraded.\n", + "Need to get 1,281 kB of archives.\n", + "After this operation, 7,686 kB of additional disk space will be used.\n", + "Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 python-opengl all 3.1.0+dfsg-1 [496 kB]\n", + "Get:2 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 xvfb amd64 2:1.19.6-1ubuntu4.9 [784 kB]\n", + "Fetched 1,281 kB in 1s (977 kB/s)\n", + "Selecting previously unselected package python-opengl.\n", + "(Reading database ... 160706 files and directories currently installed.)\n", + "Preparing to unpack .../python-opengl_3.1.0+dfsg-1_all.deb ...\n", + "Unpacking python-opengl (3.1.0+dfsg-1) ...\n", + "Selecting previously unselected package xvfb.\n", + "Preparing to unpack .../xvfb_2%3a1.19.6-1ubuntu4.9_amd64.deb ...\n", + "Unpacking xvfb (2:1.19.6-1ubuntu4.9) ...\n", + "Setting up python-opengl (3.1.0+dfsg-1) ...\n", + "Setting up xvfb (2:1.19.6-1ubuntu4.9) ...\n", + "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n", + "Requirement already satisfied: gym[box2d] in /usr/local/lib/python3.7/dist-packages (0.17.3)\n", + "Collecting pyvirtualdisplay\n", + " Downloading https://files.pythonhosted.org/packages/19/88/7a198a5ee3baa3d547f5a49574cd8c3913b216f5276b690b028f89ffb325/PyVirtualDisplay-2.1-py3-none-any.whl\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n", + "Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.3.0)\n", + "Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.19.5)\n", + "Requirement 
already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.5.0)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.4.1)\n", + "Collecting box2d-py~=2.3.5; extra == \"box2d\"\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/87/34/da5393985c3ff9a76351df6127c275dcb5749ae0abbe8d5210f06d97405d/box2d_py-2.3.8-cp37-cp37m-manylinux1_x86_64.whl (448kB)\n", + "\u001b[K |████████████████████████████████| 450kB 10.3MB/s \n", + "\u001b[?25hCollecting EasyProcess\n", + " Downloading https://files.pythonhosted.org/packages/48/3c/75573613641c90c6d094059ac28adb748560d99bd27ee6f80cce398f404e/EasyProcess-0.3-py2.py3-none-any.whl\n", + "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from pyglet<=1.5.0,>=1.4.0->gym[box2d]) (0.16.0)\n", + "Installing collected packages: EasyProcess, pyvirtualdisplay, box2d-py\n", + "Successfully installed EasyProcess-0.3 box2d-py-2.3.8 pyvirtualdisplay-2.1\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M_-i3cdoYsks" + }, + "source": [ + "接下來,設置好 virtual display,並引入所有必要的套件。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "nl2nREINDLiw" + }, + "source": [ + "%%capture\n", + "from pyvirtualdisplay import Display\n", + "virtual_display = Display(visible=0, size=(1400, 900))\n", + "virtual_display.start()\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from IPython import display\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.distributions import Categorical\n", + "from tqdm.notebook import tqdm" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HVu9-Vdrl4E3" + }, + "source": [ + "# 請不要更改 random seed !!!!\n", + "# 
不然在judgeboi上 你的成績不會被reproduce !!!!" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fV9i8i2YkRbO" + }, + "source": [ + "seed = 543 # Do not change this\n", + "def fix(env, seed):\n", + " env.seed(seed)\n", + " env.action_space.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " np.random.seed(seed)\n", + " random.seed(seed)\n", + " torch.set_deterministic(True)\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.backends.cudnn.deterministic = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "He0XDx6bzjgC" + }, + "source": [ + "最後,引入 OpenAI 的 gym,並建立一個 [Lunar Lander](https://gym.openai.com/envs/LunarLander-v2/) 環境。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "N_4-xJcbBt09" + }, + "source": [ + "%%capture\n", + "import gym\n", + "import random\n", + "import numpy as np\n", + "\n", + "env = gym.make('LunarLander-v2')\n", + "\n", + "fix(env, seed)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "NmiAOfqRwRX5" + }, + "source": [ + "import time\n", + "start = time.time()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "LcMjEUWTBEEB", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7a5146e4-e877-4d26-fd61-652c57ef1f4e" + }, + "source": [ + "!pip freeze" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "absl-py==0.12.0\n", + "alabaster==0.7.12\n", + "albumentations==0.1.12\n", + "altair==4.1.0\n", + "appdirs==1.4.4\n", + "argon2-cffi==20.1.0\n", + "arviz==0.11.2\n", + "astor==0.8.1\n", + "astropy==4.2.1\n", + "astunparse==1.6.3\n", + "async-generator==1.10\n", + "atari-py==0.2.9\n", + "atomicwrites==1.4.0\n", + "attrs==21.2.0\n", + "audioread==2.1.9\n", + "autograd==1.3\n", + "Babel==2.9.1\n", + "backcall==0.2.0\n", 
+ "beautifulsoup4==4.6.3\n", + "bleach==3.3.0\n", + "blis==0.4.1\n", + "bokeh==2.3.2\n", + "Bottleneck==1.3.2\n", + "box2d-py==2.3.8\n", + "branca==0.4.2\n", + "bs4==0.0.1\n", + "CacheControl==0.12.6\n", + "cached-property==1.5.2\n", + "cachetools==4.2.2\n", + "catalogue==1.0.0\n", + "certifi==2020.12.5\n", + "cffi==1.14.5\n", + "cftime==1.5.0\n", + "chainer==7.4.0\n", + "chardet==3.0.4\n", + "click==7.1.2\n", + "cloudpickle==1.3.0\n", + "cmake==3.12.0\n", + "cmdstanpy==0.9.5\n", + "colorcet==2.0.6\n", + "colorlover==0.3.0\n", + "community==1.0.0b1\n", + "contextlib2==0.5.5\n", + "convertdate==2.3.2\n", + "coverage==3.7.1\n", + "coveralls==0.5\n", + "crcmod==1.7\n", + "cufflinks==0.17.3\n", + "cupy-cuda101==7.4.0\n", + "cvxopt==1.2.6\n", + "cvxpy==1.0.31\n", + "cycler==0.10.0\n", + "cymem==2.0.5\n", + "Cython==0.29.23\n", + "daft==0.0.4\n", + "dask==2.12.0\n", + "datascience==0.10.6\n", + "debugpy==1.0.0\n", + "decorator==4.4.2\n", + "defusedxml==0.7.1\n", + "descartes==1.1.0\n", + "dill==0.3.3\n", + "distributed==1.25.3\n", + "dlib==19.18.0\n", + "dm-tree==0.1.6\n", + "docopt==0.6.2\n", + "docutils==0.17.1\n", + "dopamine-rl==1.0.5\n", + "earthengine-api==0.1.266\n", + "easydict==1.9\n", + "EasyProcess==0.3\n", + "ecos==2.0.7.post1\n", + "editdistance==0.5.3\n", + "en-core-web-sm==2.2.5\n", + "entrypoints==0.3\n", + "ephem==3.7.7.1\n", + "et-xmlfile==1.1.0\n", + "fa2==0.3.5\n", + "fastai==1.0.61\n", + "fastdtw==0.3.4\n", + "fastprogress==1.0.0\n", + "fastrlock==0.6\n", + "fbprophet==0.7.1\n", + "feather-format==0.4.1\n", + "filelock==3.0.12\n", + "firebase-admin==4.4.0\n", + "fix-yahoo-finance==0.0.22\n", + "Flask==1.1.4\n", + "flatbuffers==1.12\n", + "folium==0.8.3\n", + "future==0.16.0\n", + "gast==0.4.0\n", + "GDAL==2.2.2\n", + "gdown==3.6.4\n", + "gensim==3.6.0\n", + "geographiclib==1.50\n", + "geopy==1.17.0\n", + "gin-config==0.4.0\n", + "glob2==0.7\n", + "google==2.0.3\n", + "google-api-core==1.26.3\n", + "google-api-python-client==1.12.8\n", + 
"google-auth==1.30.0\n", + "google-auth-httplib2==0.0.4\n", + "google-auth-oauthlib==0.4.4\n", + "google-cloud-bigquery==1.21.0\n", + "google-cloud-bigquery-storage==1.1.0\n", + "google-cloud-core==1.0.3\n", + "google-cloud-datastore==1.8.0\n", + "google-cloud-firestore==1.7.0\n", + "google-cloud-language==1.2.0\n", + "google-cloud-storage==1.18.1\n", + "google-cloud-translate==1.5.0\n", + "google-colab==1.0.0\n", + "google-pasta==0.2.0\n", + "google-resumable-media==0.4.1\n", + "googleapis-common-protos==1.53.0\n", + "googledrivedownloader==0.4\n", + "graphviz==0.10.1\n", + "greenlet==1.1.0\n", + "grpcio==1.34.1\n", + "gspread==3.0.1\n", + "gspread-dataframe==3.0.8\n", + "gym==0.17.3\n", + "h5py==3.1.0\n", + "HeapDict==1.0.1\n", + "hijri-converter==2.1.1\n", + "holidays==0.10.5.2\n", + "holoviews==1.14.3\n", + "html5lib==1.0.1\n", + "httpimport==0.5.18\n", + "httplib2==0.17.4\n", + "httplib2shim==0.0.3\n", + "humanize==0.5.1\n", + "hyperopt==0.1.2\n", + "ideep4py==2.0.0.post3\n", + "idna==2.10\n", + "imageio==2.4.1\n", + "imagesize==1.2.0\n", + "imbalanced-learn==0.4.3\n", + "imblearn==0.0\n", + "imgaug==0.2.9\n", + "importlib-metadata==4.0.1\n", + "importlib-resources==5.1.3\n", + "imutils==0.5.4\n", + "inflect==2.1.0\n", + "iniconfig==1.1.1\n", + "install==1.3.4\n", + "intel-openmp==2021.2.0\n", + "intervaltree==2.1.0\n", + "ipykernel==4.10.1\n", + "ipython==5.5.0\n", + "ipython-genutils==0.2.0\n", + "ipython-sql==0.3.9\n", + "ipywidgets==7.6.3\n", + "itsdangerous==1.1.0\n", + "jax==0.2.13\n", + "jaxlib==0.1.66+cuda110\n", + "jdcal==1.4.1\n", + "jedi==0.18.0\n", + "jieba==0.42.1\n", + "Jinja2==2.11.3\n", + "joblib==1.0.1\n", + "jpeg4py==0.1.4\n", + "jsonschema==2.6.0\n", + "jupyter==1.0.0\n", + "jupyter-client==5.3.5\n", + "jupyter-console==5.2.0\n", + "jupyter-core==4.7.1\n", + "jupyterlab-pygments==0.1.2\n", + "jupyterlab-widgets==1.0.0\n", + "kaggle==1.5.12\n", + "kapre==0.3.5\n", + "Keras==2.4.3\n", + "keras-nightly==2.5.0.dev2021032900\n", + 
"Keras-Preprocessing==1.1.2\n", + "keras-vis==0.4.1\n", + "kiwisolver==1.3.1\n", + "korean-lunar-calendar==0.2.1\n", + "librosa==0.8.0\n", + "lightgbm==2.2.3\n", + "llvmlite==0.34.0\n", + "lmdb==0.99\n", + "LunarCalendar==0.0.9\n", + "lxml==4.2.6\n", + "Markdown==3.3.4\n", + "MarkupSafe==2.0.1\n", + "matplotlib==3.2.2\n", + "matplotlib-inline==0.1.2\n", + "matplotlib-venn==0.11.6\n", + "missingno==0.4.2\n", + "mistune==0.8.4\n", + "mizani==0.6.0\n", + "mkl==2019.0\n", + "mlxtend==0.14.0\n", + "more-itertools==8.7.0\n", + "moviepy==0.2.3.5\n", + "mpmath==1.2.1\n", + "msgpack==1.0.2\n", + "multiprocess==0.70.11.1\n", + "multitasking==0.0.9\n", + "murmurhash==1.0.5\n", + "music21==5.5.0\n", + "natsort==5.5.0\n", + "nbclient==0.5.3\n", + "nbconvert==5.6.1\n", + "nbformat==5.1.3\n", + "nest-asyncio==1.5.1\n", + "netCDF4==1.5.6\n", + "networkx==2.5.1\n", + "nibabel==3.0.2\n", + "nltk==3.2.5\n", + "notebook==5.3.1\n", + "numba==0.51.2\n", + "numexpr==2.7.3\n", + "numpy==1.19.5\n", + "nvidia-ml-py3==7.352.0\n", + "oauth2client==4.1.3\n", + "oauthlib==3.1.0\n", + "okgrade==0.4.3\n", + "opencv-contrib-python==4.1.2.30\n", + "opencv-python==4.1.2.30\n", + "openpyxl==2.5.9\n", + "opt-einsum==3.3.0\n", + "osqp==0.6.2.post0\n", + "packaging==20.9\n", + "palettable==3.3.0\n", + "pandas==1.1.5\n", + "pandas-datareader==0.9.0\n", + "pandas-gbq==0.13.3\n", + "pandas-profiling==1.4.1\n", + "pandocfilters==1.4.3\n", + "panel==0.11.3\n", + "param==1.10.1\n", + "parso==0.8.2\n", + "pathlib==1.0.1\n", + "patsy==0.5.1\n", + "pexpect==4.8.0\n", + "pickleshare==0.7.5\n", + "Pillow==7.1.2\n", + "pip-tools==4.5.1\n", + "plac==1.1.3\n", + "plotly==4.4.1\n", + "plotnine==0.6.0\n", + "pluggy==0.7.1\n", + "pooch==1.3.0\n", + "portpicker==1.3.9\n", + "prefetch-generator==1.0.1\n", + "preshed==3.0.5\n", + "prettytable==2.1.0\n", + "progressbar2==3.38.0\n", + "prometheus-client==0.10.1\n", + "promise==2.3\n", + "prompt-toolkit==1.0.18\n", + "protobuf==3.12.4\n", + "psutil==5.4.8\n", + 
"psycopg2==2.7.6.1\n", + "ptyprocess==0.7.0\n", + "py==1.10.0\n", + "pyarrow==3.0.0\n", + "pyasn1==0.4.8\n", + "pyasn1-modules==0.2.8\n", + "pycocotools==2.0.2\n", + "pycparser==2.20\n", + "pyct==0.4.8\n", + "pydata-google-auth==1.2.0\n", + "pydot==1.3.0\n", + "pydot-ng==2.0.0\n", + "pydotplus==2.0.2\n", + "PyDrive==1.3.1\n", + "pyemd==0.5.1\n", + "pyerfa==2.0.0\n", + "pyglet==1.5.0\n", + "Pygments==2.6.1\n", + "pygobject==3.26.1\n", + "pymc3==3.11.2\n", + "PyMeeus==0.5.11\n", + "pymongo==3.11.4\n", + "pymystem3==0.2.0\n", + "PyOpenGL==3.1.5\n", + "pyparsing==2.4.7\n", + "pyrsistent==0.17.3\n", + "pysndfile==1.3.8\n", + "PySocks==1.7.1\n", + "pystan==2.19.1.1\n", + "pytest==3.6.4\n", + "python-apt==0.0.0\n", + "python-chess==0.23.11\n", + "python-dateutil==2.8.1\n", + "python-louvain==0.15\n", + "python-slugify==5.0.2\n", + "python-utils==2.5.6\n", + "pytz==2018.9\n", + "PyVirtualDisplay==2.1\n", + "pyviz-comms==2.0.1\n", + "PyWavelets==1.1.1\n", + "PyYAML==3.13\n", + "pyzmq==22.0.3\n", + "qdldl==0.1.5.post0\n", + "qtconsole==5.1.0\n", + "QtPy==1.9.0\n", + "regex==2019.12.20\n", + "requests==2.23.0\n", + "requests-oauthlib==1.3.0\n", + "resampy==0.2.2\n", + "retrying==1.3.3\n", + "rpy2==3.4.4\n", + "rsa==4.7.2\n", + "scikit-image==0.16.2\n", + "scikit-learn==0.22.2.post1\n", + "scipy==1.4.1\n", + "screen-resolution-extra==0.0.0\n", + "scs==2.1.3\n", + "seaborn==0.11.1\n", + "semver==2.13.0\n", + "Send2Trash==1.5.0\n", + "setuptools-git==1.2\n", + "Shapely==1.7.1\n", + "simplegeneric==0.8.1\n", + "six==1.15.0\n", + "sklearn==0.0\n", + "sklearn-pandas==1.8.0\n", + "smart-open==5.0.0\n", + "snowballstemmer==2.1.0\n", + "sortedcontainers==2.4.0\n", + "SoundFile==0.10.3.post1\n", + "spacy==2.2.4\n", + "Sphinx==1.8.5\n", + "sphinxcontrib-serializinghtml==1.1.4\n", + "sphinxcontrib-websupport==1.2.4\n", + "SQLAlchemy==1.4.15\n", + "sqlparse==0.4.1\n", + "srsly==1.0.5\n", + "statsmodels==0.10.2\n", + "sympy==1.7.1\n", + "tables==3.4.4\n", + "tabulate==0.8.9\n", + 
"tblib==1.7.0\n", + "tensorboard==2.5.0\n", + "tensorboard-data-server==0.6.1\n", + "tensorboard-plugin-wit==1.8.0\n", + "tensorflow==2.5.0\n", + "tensorflow-datasets==4.0.1\n", + "tensorflow-estimator==2.5.0\n", + "tensorflow-gcs-config==2.5.0\n", + "tensorflow-hub==0.12.0\n", + "tensorflow-metadata==1.0.0\n", + "tensorflow-probability==0.12.1\n", + "termcolor==1.1.0\n", + "terminado==0.10.0\n", + "testpath==0.5.0\n", + "text-unidecode==1.3\n", + "textblob==0.15.3\n", + "Theano-PyMC==1.1.2\n", + "thinc==7.4.0\n", + "tifffile==2021.4.8\n", + "toml==0.10.2\n", + "toolz==0.11.1\n", + "torch==1.8.1+cu101\n", + "torchsummary==1.5.1\n", + "torchtext==0.9.1\n", + "torchvision==0.9.1+cu101\n", + "tornado==5.1.1\n", + "tqdm==4.41.1\n", + "traitlets==5.0.5\n", + "tweepy==3.10.0\n", + "typeguard==2.7.1\n", + "typing-extensions==3.7.4.3\n", + "tzlocal==1.5.1\n", + "uritemplate==3.0.1\n", + "urllib3==1.24.3\n", + "vega-datasets==0.9.0\n", + "wasabi==0.8.2\n", + "wcwidth==0.2.5\n", + "webencodings==0.5.1\n", + "Werkzeug==1.0.1\n", + "widgetsnbextension==3.5.1\n", + "wordcloud==1.5.0\n", + "wrapt==1.12.1\n", + "xarray==0.18.2\n", + "xgboost==0.90\n", + "xkit==0.0.0\n", + "xlrd==1.1.0\n", + "xlwt==1.3.0\n", + "yellowbrick==0.9.1\n", + "zict==2.0.0\n", + "zipp==3.4.1\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NrkVvTrvWZ5H" + }, + "source": [ + "## 什麼是 Lunar Lander?\n", + "\n", + "“LunarLander-v2” 這個環境是在模擬登月小艇降落在月球表面時的情形。\n", + "這個任務的目標是讓登月小艇「安全地」降落在兩個黃色旗幟間的平地上。\n", + "> Landing pad is always at coordinates (0,0).\n", + "> Coordinates are the first two numbers in state vector.\n", + "\n", + "![](https://gym.openai.com/assets/docs/aeloop-138c89d44114492fd02822303e6b4b07213010bb14ca5856d2d49d6b62d88e53.svg)\n", + "\n", + "所謂的「環境」其實同時包括了 agent 和 environment。\n", + "我們利用 `step()` 這個函式讓 agent 行動,而後函式便會回傳 environment 給予的 observation/state(以下這兩個名詞代表同樣的意思)和 reward。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"bIbp82sljvAt" + }, + "source": [ + "### Observation / State\n", + "\n", + "首先,我們可以看看 environment 回傳給 agent 的 observation 究竟是長什麼樣子的資料:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rsXZra3N9R5T", + "outputId": "9512a449-f90a-4545-8aef-dd9aeb9b2b9e" + }, + "source": [ + "print(env.observation_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Box(-inf, inf, (8,), float32)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ezdfoThbAQ49" + }, + "source": [ + "`Box(-inf, inf, (8,), float32)`(部分 gym 版本會顯示為 `Box(8,)`)說明我們會拿到 8 維的向量作為 observation,其中包含:水平及垂直座標、水平及垂直速度、角度、角速度,以及左右兩隻腳是否觸地,這部分我們就不細說。\n", + "\n", + "### Action\n", + "\n", + "而在 agent 得到 observation 和 reward 以後,能夠採取的動作有:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "p1k4dIrBAaKi", + "outputId": "64cd523a-bbff-4569-cae9-f65123b3c604" + }, + "source": [ + "print(env.action_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Discrete(4)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dejXT6PHBrPn" + }, + "source": [ + "`Discrete(4)` 說明 agent 可以採取四種離散的行動:\n", + "- 0 代表不採取任何行動\n", + "- 2 代表主引擎向下噴射\n", + "- 1, 3 則是向左右噴射\n", + "\n", + "接下來,我們嘗試讓 agent 與 environment 互動。\n", + "在進行任何操作前,建議先呼叫 `reset()` 函式讓整個「環境」重置。\n", + "而這個函式同時會回傳「環境」最初始的狀態。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pi4OmrmZgnWA", + "outputId": "c358ff73-1879-4a74-9579-9ee97740dc16" + }, + "source": [ + "initial_state = env.reset()\n", + "print(initial_state)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.00396109 1.4083536 0.40119505 -0.11407257 -0.00458307 -0.09087662\n", + " 0. 0. 
]\n", + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uBx0mEqqgxJ9" + }, + "source": [ + "接著,我們試著從 agent 的行動空間(共四種動作)中,隨機抽樣一個行動:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxkOEXRKgizt", + "outputId": "8912cf80-2310-401b-a37e-c0ded59626ee" + }, + "source": [ + "random_action = env.action_space.sample()\n", + "print(random_action)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mns-bO01g0-J" + }, + "source": [ + "再利用 `step()` 函式讓 agent 執行我們隨機抽樣出來的動作 `random_action`。\n", + "而這個函式會回傳四項資訊:\n", + "- observation / state\n", + "- reward\n", + "- 完成與否\n", + "- 其餘資訊" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "E_WViSxGgIk9" + }, + "source": [ + "observation, reward, done, info = env.step(random_action)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FdieGq7NuBIm" + }, + "source": [ + "第一項資訊 `observation` 即為 agent 採取行動之後,對於環境的 observation(或者說環境的 state)。\n", + "而第三項資訊 `done` 則是 `True` 或 `False` 的布林值,當登月小艇成功著陸或是不幸墜毀時,代表這個回合(episode)也就跟著結束了,此時 `step()` 函式便會回傳 `done = True`,而在那之前,`done` 則保持 `False`。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yK7r126kuCNp", + "outputId": "3b99114f-e6b4-4a18-c80b-75189083bd55" + }, + "source": [ + "print(done)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "False\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GKdS8vOihxhc" + }, + "source": [ + "### Reward\n", + "\n", + "而「環境」給予的 reward 大致是這樣計算:\n", + "- 小艇墜毀得到 -100 分\n", + "- 小艇在黃旗幟之間成功著地則得 100~140 分\n", + "- 噴射主引擎(向下噴火)每個 frame -0.3 分\n", + "- 小艇最終完全靜止則再得 100 分\n", + "- 小艇每隻腳碰觸地面 +10 分\n", + 
"\n", + "> Reward for moving from the top of the screen to landing pad and zero speed is about 100..140 points.\n", + "> If lander moves away from landing pad it loses reward back.\n", + "> Episode finishes if the lander crashes or comes to rest, receiving additional -100 or +100 points.\n", + "> Each leg ground contact is +10.\n", + "> Firing main engine is -0.3 points each frame.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxQNs77hi0_7", + "outputId": "dacd87b3-734e-44f3-c5b4-361b323def84" + }, + "source": [ + "print(reward) # after doing a random action (0), the immediate reward is stored in this " + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-0.8588900517154912\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mhqp6D-XgHpe" + }, + "source": [ + "### Random Agent\n", + "\n", + "最後,在進入實做之前,我們就來看看這樣一個 random agent 能否成功登陸月球:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 269 + }, + "id": "Y3G0bxoccelv", + "outputId": "36096915-445e-40fb-b349-a6a9a5b900d5" + }, + "source": [ + "\n", + "env.reset()\n", + "\n", + "img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + "done = False\n", + "while not done:\n", + " action = env.action_space.sample()\n", + " observation, reward, done, _ = env.step(action)\n", + "\n", + " img.set_data(env.render(mode='rgb_array'))\n", + " display.display(plt.gcf())\n", + " display.clear_output(wait=True)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F5paWqo7tWL2" + }, + "source": [ + "## Policy Gradient\n", + "\n", + "現在來搭建一個簡單的 policy network。\n", + "我們預設模型的輸入是 8-dim 的 observation,輸出則是離散的四個動作之一:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "J8tdmeD-tZew" + }, + "source": [ + "class PolicyGradientNetwork(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(8, 16)\n", + " self.fc2 = nn.Linear(16, 16)\n", + " self.fc3 = nn.Linear(16, 4)\n", + "\n", + " def forward(self, state):\n", + " hid = torch.tanh(self.fc1(state))\n", + " hid = torch.tanh(self.fc2(hid))\n", + " return F.softmax(self.fc3(hid), dim=-1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ynbqJrhIFTC3" + }, + "source": [ + "再來,搭建一個簡單的 agent,並搭配上方的 policy network 來採取行動。\n", + "這個 agent 能做到以下幾件事:\n", + "- `learn()`:從記下來的 log probabilities 及 rewards 來更新 policy network。\n", + "- `sample()`:從 environment 得到 observation 之後,利用 policy network 得出應該採取的行動。\n", + "而此函式除了回傳抽樣出來的 action,也會回傳此次抽樣的 log probabilities。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zZo-IxJx286z" + }, + "source": [ + "\n", + "class PolicyGradientAgent():\n", + " \n", + " def __init__(self, network):\n", + " self.network = network\n", + " self.optimizer = optim.SGD(self.network.parameters(), lr=0.001)\n", + " \n", + " def forward(self, state):\n", + " return self.network(state)\n", + " def learn(self, log_probs, rewards):\n", + " loss = (-log_probs * rewards).sum() # You don't need to revise this to pass simple baseline (but you can)\n", + "\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " \n", + " def sample(self, state):\n", + " action_prob = self.network(torch.FloatTensor(state))\n", + " action_dist = Categorical(action_prob)\n", + " action = 
action_dist.sample()\n", + " log_prob = action_dist.log_prob(action)\n", + " return action.item(), log_prob\n", + "\n", + " def save(self, PATH): # You should not revise this\n", + " Agent_Dict = {\n", + " \"network\" : self.network.state_dict(),\n", + " \"optimizer\" : self.optimizer.state_dict()\n", + " }\n", + " torch.save(Agent_Dict, PATH)\n", + "\n", + " def load(self, PATH): # You should not revise this\n", + " checkpoint = torch.load(PATH)\n", + " self.network.load_state_dict(checkpoint[\"network\"])\n", + " #如果要儲存過程或是中斷訓練後想繼續可以用喔 ^_^\n", + " self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehPlnTKyRZf9" + }, + "source": [ + "最後,建立一個 network 和 agent,就可以開始進行訓練了。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GfJIvML-RYjL" + }, + "source": [ + "network = PolicyGradientNetwork()\n", + "agent = PolicyGradientAgent(network)\n", + "#agent = PolicyGradientAgent()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ouv23glgf5Qt" + }, + "source": [ + "## 訓練 Agent\n", + "\n", + "現在我們開始訓練 agent。\n", + "透過讓 agent 和 environment 互動,我們記住每一組對應的 log probabilities 及 reward,並在成功登陸或者不幸墜毀後,回放這些「記憶」來訓練 policy network。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "2acab9542fe64b979fa2ac2adb3f10a8", + "f288c64b5ff748eb82178bf1de17934f", + "de34e5b178f5470e98e0275102a65042", + "c93cba301cac439ca56fb6b45bd1c4e4", + "43c6ee720b674626ab3a869bda5dd6e3", + "2465d2b109d34922a486341232d86ad6", + "aa27187195be4da9874025395eac35eb", + "02d196d4f9734f998455d92bd9300adb" + ] + }, + "id": "vg5rxBBaf38_", + "outputId": "eae0c9f4-0efc-40fe-a29e-7f7194613f6d" + }, + "source": [ + "agent.network.train() # 訓練前,先確保 network 處在 training 模式\n", + "EPISODE_PER_BATCH = 5 # 每蒐集 5 個 episodes 更新一次 agent\n", + 
"NUM_BATCH = 400 # 總共更新 400 次\n", + "\n", + "avg_total_rewards, avg_final_rewards = [], []\n", + "\n", + "prg_bar = tqdm(range(NUM_BATCH))\n", + "for batch in prg_bar:\n", + "\n", + " log_probs, rewards = [], []\n", + " total_rewards, final_rewards = [], []\n", + "\n", + " # 蒐集訓練資料\n", + " for episode in range(EPISODE_PER_BATCH):\n", + " \n", + " state = env.reset()\n", + " total_reward, total_step = 0, 0\n", + " seq_rewards = []\n", + " while True:\n", + "\n", + " action, log_prob = agent.sample(state) # at , log(at|st)\n", + " next_state, reward, done, _ = env.step(action)\n", + "\n", + " log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]\n", + " # seq_rewards.append(reward)\n", + " state = next_state\n", + " total_reward += reward\n", + " total_step += 1\n", + " rewards.append(reward) #改這裡\n", + " # ! 重要 !\n", + " # 現在的reward 的implementation 為每個時刻的瞬時reward, 給定action_list : a1, a2, a3 ......\n", + " # reward : r1, r2 ,r3 ......\n", + " # medium:將reward調整成accumulative decaying reward, 給定action_list : a1, a2, a3 ......\n", + " # reward : r1+0.99*r2+0.99^2*r3+......, r2+0.99*r3+0.99^2*r4+...... 
,r3+0.99*r4+0.99^2*r5+ ......\n", + " # boss : implement DQN\n", + " if done:\n", + " final_rewards.append(reward)\n", + " total_rewards.append(total_reward)\n", + " break\n", + "\n", + " print(f\"rewards looks like \", np.shape(rewards)) \n", + " print(f\"log_probs looks like \", np.shape(log_probs)) \n", + " # 紀錄訓練過程\n", + " avg_total_reward = sum(total_rewards) / len(total_rewards)\n", + " avg_final_reward = sum(final_rewards) / len(final_rewards)\n", + " avg_total_rewards.append(avg_total_reward)\n", + " avg_final_rewards.append(avg_final_reward)\n", + " prg_bar.set_description(f\"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}\")\n", + "\n", + " # 更新網路\n", + " # rewards = np.concatenate(rewards, axis=0)\n", + " rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9) # 將 reward 正規標準化\n", + " agent.learn(torch.stack(log_probs), torch.from_numpy(rewards))\n", + " print(\"logs prob looks like \", torch.stack(log_probs).size())\n", + " print(\"torch.from_numpy(rewards) looks like \", torch.from_numpy(rewards).size())" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2acab9542fe64b979fa2ac2adb3f10a8", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "rewards looks like (448,)\n", + "log_probs looks like (448,)\n", + "logs prob looks like torch.Size([448])\n", + "torch.from_numpy(rewards) looks like torch.Size([448])\n", + "rewards looks like (515,)\n", + "log_probs looks like (515,)\n", + "logs prob looks like torch.Size([515])\n", + "torch.from_numpy(rewards) looks like torch.Size([515])\n", + "rewards looks like (392,)\n", + "log_probs looks like (392,)\n", + "logs prob looks like torch.Size([392])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([392])\n", + "rewards looks like (518,)\n", + "log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (472,)\n", + "log_probs looks like (472,)\n", + "logs prob looks like torch.Size([472])\n", + "torch.from_numpy(rewards) looks like torch.Size([472])\n", + "rewards looks like (530,)\n", + "log_probs looks like (530,)\n", + "logs prob looks like torch.Size([530])\n", + "torch.from_numpy(rewards) looks like torch.Size([530])\n", + "rewards looks like (463,)\n", + "log_probs looks like (463,)\n", + "logs prob looks like torch.Size([463])\n", + "torch.from_numpy(rewards) looks like torch.Size([463])\n", + "rewards looks like (540,)\n", + "log_probs looks like (540,)\n", + "logs prob looks like torch.Size([540])\n", + "torch.from_numpy(rewards) looks like torch.Size([540])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (449,)\n", + "log_probs looks like (449,)\n", + "logs prob looks like torch.Size([449])\n", + "torch.from_numpy(rewards) looks like torch.Size([449])\n", + "rewards looks like (602,)\n", + "log_probs looks like (602,)\n", + "logs prob looks like torch.Size([602])\n", + "torch.from_numpy(rewards) looks like torch.Size([602])\n", + "rewards looks like (542,)\n", + "log_probs looks like (542,)\n", + "logs prob looks like torch.Size([542])\n", + "torch.from_numpy(rewards) looks like torch.Size([542])\n", + "rewards looks like (503,)\n", + "log_probs looks like (503,)\n", + "logs prob looks like torch.Size([503])\n", + "torch.from_numpy(rewards) looks like torch.Size([503])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (518,)\n", + 
"log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (421,)\n", + "log_probs looks like (421,)\n", + "logs prob looks like torch.Size([421])\n", + "torch.from_numpy(rewards) looks like torch.Size([421])\n", + "rewards looks like (592,)\n", + "log_probs looks like (592,)\n", + "logs prob looks like torch.Size([592])\n", + "torch.from_numpy(rewards) looks like torch.Size([592])\n", + "rewards looks like (520,)\n", + "log_probs looks like (520,)\n", + "logs prob looks like torch.Size([520])\n", + "torch.from_numpy(rewards) looks like torch.Size([520])\n", + "rewards looks like (494,)\n", + "log_probs looks like (494,)\n", + "logs prob looks like torch.Size([494])\n", + "torch.from_numpy(rewards) looks like torch.Size([494])\n", + "rewards looks like (461,)\n", + "log_probs looks like (461,)\n", + "logs prob looks like torch.Size([461])\n", + "torch.from_numpy(rewards) looks like torch.Size([461])\n", + "rewards looks like (572,)\n", + "log_probs looks like (572,)\n", + "logs prob looks like torch.Size([572])\n", + "torch.from_numpy(rewards) looks like torch.Size([572])\n", + "rewards looks like (593,)\n", + "log_probs looks like (593,)\n", + "logs prob looks like torch.Size([593])\n", + "torch.from_numpy(rewards) looks like torch.Size([593])\n", + "rewards looks like (569,)\n", + "log_probs looks like (569,)\n", + "logs prob looks like torch.Size([569])\n", + "torch.from_numpy(rewards) looks like torch.Size([569])\n", + "rewards looks like (546,)\n", + "log_probs looks like (546,)\n", + "logs prob looks like torch.Size([546])\n", + "torch.from_numpy(rewards) looks like torch.Size([546])\n", + "rewards looks like (612,)\n", + "log_probs looks like (612,)\n", + "logs prob looks like torch.Size([612])\n", + "torch.from_numpy(rewards) looks like torch.Size([612])\n", + "rewards looks like (534,)\n", + "log_probs looks like (534,)\n", + "logs prob looks like 
torch.Size([534])\n", + "torch.from_numpy(rewards) looks like torch.Size([534])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (535,)\n", + "log_probs looks like (535,)\n", + "logs prob looks like torch.Size([535])\n", + "torch.from_numpy(rewards) looks like torch.Size([535])\n", + "rewards looks like (533,)\n", + "log_probs looks like (533,)\n", + "logs prob looks like torch.Size([533])\n", + "torch.from_numpy(rewards) looks like torch.Size([533])\n", + "rewards looks like (521,)\n", + "log_probs looks like (521,)\n", + "logs prob looks like torch.Size([521])\n", + "torch.from_numpy(rewards) looks like torch.Size([521])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (575,)\n", + "log_probs looks like (575,)\n", + "logs prob looks like torch.Size([575])\n", + "torch.from_numpy(rewards) looks like torch.Size([575])\n", + "rewards looks like (709,)\n", + "log_probs looks like (709,)\n", + "logs prob looks like torch.Size([709])\n", + "torch.from_numpy(rewards) looks like torch.Size([709])\n", + "rewards looks like (486,)\n", + "log_probs looks like (486,)\n", + "logs prob looks like torch.Size([486])\n", + "torch.from_numpy(rewards) looks like torch.Size([486])\n", + "rewards looks like (557,)\n", + "log_probs looks like (557,)\n", + "logs prob looks like torch.Size([557])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([557])\n", + "rewards looks like (517,)\n", + "log_probs looks like (517,)\n", + "logs prob looks like torch.Size([517])\n", + "torch.from_numpy(rewards) looks like torch.Size([517])\n", + "rewards looks like (550,)\n", + "log_probs looks like (550,)\n", + "logs prob looks like torch.Size([550])\n", + "torch.from_numpy(rewards) looks like torch.Size([550])\n", + "rewards looks like (690,)\n", + "log_probs looks like (690,)\n", + "logs prob looks like torch.Size([690])\n", + "torch.from_numpy(rewards) looks like torch.Size([690])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (689,)\n", + "log_probs looks like (689,)\n", + "logs prob looks like torch.Size([689])\n", + "torch.from_numpy(rewards) looks like torch.Size([689])\n", + "rewards looks like (1059,)\n", + "log_probs looks like (1059,)\n", + "logs prob looks like torch.Size([1059])\n", + "torch.from_numpy(rewards) looks like torch.Size([1059])\n", + "rewards looks like (619,)\n", + "log_probs looks like (619,)\n", + "logs prob looks like torch.Size([619])\n", + "torch.from_numpy(rewards) looks like torch.Size([619])\n", + "rewards looks like (527,)\n", + "log_probs looks like (527,)\n", + "logs prob looks like torch.Size([527])\n", + "torch.from_numpy(rewards) looks like torch.Size([527])\n", + "rewards looks like (514,)\n", + "log_probs looks like (514,)\n", + "logs prob looks like torch.Size([514])\n", + "torch.from_numpy(rewards) looks like torch.Size([514])\n", + "rewards looks like (655,)\n", + "log_probs looks like (655,)\n", + "logs prob looks like torch.Size([655])\n", + "torch.from_numpy(rewards) looks like torch.Size([655])\n", + "rewards looks like (667,)\n", + "log_probs looks like (667,)\n", + "logs prob looks like torch.Size([667])\n", + "torch.from_numpy(rewards) looks like torch.Size([667])\n", + "rewards looks like (712,)\n", + 
"log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (636,)\n", + "log_probs looks like (636,)\n", + "logs prob looks like torch.Size([636])\n", + "torch.from_numpy(rewards) looks like torch.Size([636])\n", + "rewards looks like (620,)\n", + "log_probs looks like (620,)\n", + "logs prob looks like torch.Size([620])\n", + "torch.from_numpy(rewards) looks like torch.Size([620])\n", + "rewards looks like (543,)\n", + "log_probs looks like (543,)\n", + "logs prob looks like torch.Size([543])\n", + "torch.from_numpy(rewards) looks like torch.Size([543])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (498,)\n", + "log_probs looks like (498,)\n", + "logs prob looks like torch.Size([498])\n", + "torch.from_numpy(rewards) looks like torch.Size([498])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (693,)\n", + "log_probs looks like (693,)\n", + "logs prob looks like torch.Size([693])\n", + "torch.from_numpy(rewards) looks like torch.Size([693])\n", + "rewards looks like (648,)\n", + "log_probs looks like (648,)\n", + "logs prob looks like torch.Size([648])\n", + "torch.from_numpy(rewards) looks like torch.Size([648])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (574,)\n", + "log_probs looks like (574,)\n", + "logs prob looks like 
torch.Size([574])\n", + "torch.from_numpy(rewards) looks like torch.Size([574])\n", + "rewards looks like (718,)\n", + "log_probs looks like (718,)\n", + "logs prob looks like torch.Size([718])\n", + "torch.from_numpy(rewards) looks like torch.Size([718])\n", + "rewards looks like (730,)\n", + "log_probs looks like (730,)\n", + "logs prob looks like torch.Size([730])\n", + "torch.from_numpy(rewards) looks like torch.Size([730])\n", + "rewards looks like (668,)\n", + "log_probs looks like (668,)\n", + "logs prob looks like torch.Size([668])\n", + "torch.from_numpy(rewards) looks like torch.Size([668])\n", + "rewards looks like (754,)\n", + "log_probs looks like (754,)\n", + "logs prob looks like torch.Size([754])\n", + "torch.from_numpy(rewards) looks like torch.Size([754])\n", + "rewards looks like (712,)\n", + "log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (585,)\n", + "log_probs looks like (585,)\n", + "logs prob looks like torch.Size([585])\n", + "torch.from_numpy(rewards) looks like torch.Size([585])\n", + "rewards looks like (512,)\n", + "log_probs looks like (512,)\n", + "logs prob looks like torch.Size([512])\n", + "torch.from_numpy(rewards) looks like torch.Size([512])\n", + "rewards looks like (702,)\n", + "log_probs looks like (702,)\n", + "logs prob looks like torch.Size([702])\n", + "torch.from_numpy(rewards) looks like torch.Size([702])\n", + "rewards looks like (596,)\n", + "log_probs looks like (596,)\n", + "logs prob looks like torch.Size([596])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([596])\n", + "rewards looks like (626,)\n", + "log_probs looks like (626,)\n", + "logs prob looks like torch.Size([626])\n", + "torch.from_numpy(rewards) looks like torch.Size([626])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (717,)\n", + "log_probs looks like (717,)\n", + "logs prob looks like torch.Size([717])\n", + "torch.from_numpy(rewards) looks like torch.Size([717])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (565,)\n", + "log_probs looks like (565,)\n", + "logs prob looks like torch.Size([565])\n", + "torch.from_numpy(rewards) looks like torch.Size([565])\n", + "rewards looks like (450,)\n", + "log_probs looks like (450,)\n", + "logs prob looks like torch.Size([450])\n", + "torch.from_numpy(rewards) looks like torch.Size([450])\n", + "rewards looks like (584,)\n", + "log_probs looks like (584,)\n", + "logs prob looks like torch.Size([584])\n", + "torch.from_numpy(rewards) looks like torch.Size([584])\n", + "rewards looks like (670,)\n", + "log_probs looks like (670,)\n", + "logs prob looks like torch.Size([670])\n", + "torch.from_numpy(rewards) looks like torch.Size([670])\n", + "rewards looks like (691,)\n", + "log_probs looks like (691,)\n", + "logs prob looks like torch.Size([691])\n", + "torch.from_numpy(rewards) looks like torch.Size([691])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (752,)\n", + "log_probs looks like (752,)\n", + "logs prob looks like torch.Size([752])\n", + "torch.from_numpy(rewards) looks like torch.Size([752])\n", + "rewards looks like (478,)\n", + 
"log_probs looks like (478,)\n", + "logs prob looks like torch.Size([478])\n", + "torch.from_numpy(rewards) looks like torch.Size([478])\n", + "rewards looks like (553,)\n", + "log_probs looks like (553,)\n", + "logs prob looks like torch.Size([553])\n", + "torch.from_numpy(rewards) looks like torch.Size([553])\n", + "rewards looks like (1660,)\n", + "log_probs looks like (1660,)\n", + "logs prob looks like torch.Size([1660])\n", + "torch.from_numpy(rewards) looks like torch.Size([1660])\n", + "rewards looks like (751,)\n", + "log_probs looks like (751,)\n", + "logs prob looks like torch.Size([751])\n", + "torch.from_numpy(rewards) looks like torch.Size([751])\n", + "rewards looks like (801,)\n", + "log_probs looks like (801,)\n", + "logs prob looks like torch.Size([801])\n", + "torch.from_numpy(rewards) looks like torch.Size([801])\n", + "rewards looks like (715,)\n", + "log_probs looks like (715,)\n", + "logs prob looks like torch.Size([715])\n", + "torch.from_numpy(rewards) looks like torch.Size([715])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (609,)\n", + "log_probs looks like (609,)\n", + "logs prob looks like torch.Size([609])\n", + "torch.from_numpy(rewards) looks like torch.Size([609])\n", + "rewards looks like (732,)\n", + "log_probs looks like (732,)\n", + "logs prob looks like torch.Size([732])\n", + "torch.from_numpy(rewards) looks like torch.Size([732])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like 
torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (658,)\n", + "log_probs looks like (658,)\n", + "logs prob looks like torch.Size([658])\n", + "torch.from_numpy(rewards) looks like torch.Size([658])\n", + "rewards looks like (783,)\n", + "log_probs looks like (783,)\n", + "logs prob looks like torch.Size([783])\n", + "torch.from_numpy(rewards) looks like torch.Size([783])\n", + "rewards looks like (652,)\n", + "log_probs looks like (652,)\n", + "logs prob looks like torch.Size([652])\n", + "torch.from_numpy(rewards) looks like torch.Size([652])\n", + "rewards looks like (892,)\n", + "log_probs looks like (892,)\n", + "logs prob looks like torch.Size([892])\n", + "torch.from_numpy(rewards) looks like torch.Size([892])\n", + "rewards looks like (821,)\n", + "log_probs looks like (821,)\n", + "logs prob looks like torch.Size([821])\n", + "torch.from_numpy(rewards) looks like torch.Size([821])\n", + "rewards looks like (986,)\n", + "log_probs looks like (986,)\n", + "logs prob looks like torch.Size([986])\n", + "torch.from_numpy(rewards) looks like torch.Size([986])\n", + "rewards looks like (916,)\n", + "log_probs looks like (916,)\n", + "logs prob looks like torch.Size([916])\n", + "torch.from_numpy(rewards) looks like torch.Size([916])\n", + "rewards looks like (742,)\n", + "log_probs looks like (742,)\n", + "logs prob looks like torch.Size([742])\n", + "torch.from_numpy(rewards) looks like torch.Size([742])\n", + "rewards looks like (604,)\n", + "log_probs looks like (604,)\n", + "logs prob looks like torch.Size([604])\n", + "torch.from_numpy(rewards) looks like torch.Size([604])\n", + "rewards looks like (818,)\n", + "log_probs looks like (818,)\n", + "logs prob looks like torch.Size([818])\n", + "torch.from_numpy(rewards) looks like torch.Size([818])\n", + "rewards looks like (855,)\n", + "log_probs looks like (855,)\n", + "logs prob looks like torch.Size([855])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([855])\n", + "rewards looks like (795,)\n", + "log_probs looks like (795,)\n", + "logs prob looks like torch.Size([795])\n", + "torch.from_numpy(rewards) looks like torch.Size([795])\n", + "rewards looks like (868,)\n", + "log_probs looks like (868,)\n", + "logs prob looks like torch.Size([868])\n", + "torch.from_numpy(rewards) looks like torch.Size([868])\n", + "rewards looks like (800,)\n", + "log_probs looks like (800,)\n", + "logs prob looks like torch.Size([800])\n", + "torch.from_numpy(rewards) looks like torch.Size([800])\n", + "rewards looks like (820,)\n", + "log_probs looks like (820,)\n", + "logs prob looks like torch.Size([820])\n", + "torch.from_numpy(rewards) looks like torch.Size([820])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (886,)\n", + "log_probs looks like (886,)\n", + "logs prob looks like torch.Size([886])\n", + "torch.from_numpy(rewards) looks like torch.Size([886])\n", + "rewards looks like (1027,)\n", + "log_probs looks like (1027,)\n", + "logs prob looks like torch.Size([1027])\n", + "torch.from_numpy(rewards) looks like torch.Size([1027])\n", + "rewards looks like (819,)\n", + "log_probs looks like (819,)\n", + "logs prob looks like torch.Size([819])\n", + "torch.from_numpy(rewards) looks like torch.Size([819])\n", + "rewards looks like (934,)\n", + "log_probs looks like (934,)\n", + "logs prob looks like torch.Size([934])\n", + "torch.from_numpy(rewards) looks like torch.Size([934])\n", + "rewards looks like (1648,)\n", + "log_probs looks like (1648,)\n", + "logs prob looks like torch.Size([1648])\n", + "torch.from_numpy(rewards) looks like torch.Size([1648])\n", + "rewards looks like (1057,)\n", + "log_probs looks like (1057,)\n", + "logs prob looks like torch.Size([1057])\n", + "torch.from_numpy(rewards) looks like torch.Size([1057])\n", + "rewards looks like 
(861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1533,)\n", + "log_probs looks like (1533,)\n", + "logs prob looks like torch.Size([1533])\n", + "torch.from_numpy(rewards) looks like torch.Size([1533])\n", + "rewards looks like (920,)\n", + "log_probs looks like (920,)\n", + "logs prob looks like torch.Size([920])\n", + "torch.from_numpy(rewards) looks like torch.Size([920])\n", + "rewards looks like (905,)\n", + "log_probs looks like (905,)\n", + "logs prob looks like torch.Size([905])\n", + "torch.from_numpy(rewards) looks like torch.Size([905])\n", + "rewards looks like (814,)\n", + "log_probs looks like (814,)\n", + "logs prob looks like torch.Size([814])\n", + "torch.from_numpy(rewards) looks like torch.Size([814])\n", + "rewards looks like (809,)\n", + "log_probs looks like (809,)\n", + "logs prob looks like torch.Size([809])\n", + "torch.from_numpy(rewards) looks like torch.Size([809])\n", + "rewards looks like (873,)\n", + "log_probs looks like (873,)\n", + "logs prob looks like torch.Size([873])\n", + "torch.from_numpy(rewards) looks like torch.Size([873])\n", + "rewards looks like (727,)\n", + "log_probs looks like (727,)\n", + "logs prob looks like torch.Size([727])\n", + "torch.from_numpy(rewards) looks like torch.Size([727])\n", + "rewards looks like (1129,)\n", + "log_probs looks like (1129,)\n", + "logs prob looks like torch.Size([1129])\n", + "torch.from_numpy(rewards) looks like torch.Size([1129])\n", + "rewards looks like (1394,)\n", + "log_probs looks like (1394,)\n", + "logs prob looks like torch.Size([1394])\n", + "torch.from_numpy(rewards) looks like torch.Size([1394])\n", + "rewards looks like (884,)\n", + "log_probs looks like (884,)\n", + "logs prob looks like torch.Size([884])\n", + "torch.from_numpy(rewards) looks like torch.Size([884])\n", + "rewards looks like (1132,)\n", + "log_probs looks like (1132,)\n", + 
"logs prob looks like torch.Size([1132])\n", + "torch.from_numpy(rewards) looks like torch.Size([1132])\n", + "rewards looks like (1007,)\n", + "log_probs looks like (1007,)\n", + "logs prob looks like torch.Size([1007])\n", + "torch.from_numpy(rewards) looks like torch.Size([1007])\n", + "rewards looks like (711,)\n", + "log_probs looks like (711,)\n", + "logs prob looks like torch.Size([711])\n", + "torch.from_numpy(rewards) looks like torch.Size([711])\n", + "rewards looks like (836,)\n", + "log_probs looks like (836,)\n", + "logs prob looks like torch.Size([836])\n", + "torch.from_numpy(rewards) looks like torch.Size([836])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (896,)\n", + "log_probs looks like (896,)\n", + "logs prob looks like torch.Size([896])\n", + "torch.from_numpy(rewards) looks like torch.Size([896])\n", + "rewards looks like (912,)\n", + "log_probs looks like (912,)\n", + "logs prob looks like torch.Size([912])\n", + "torch.from_numpy(rewards) looks like torch.Size([912])\n", + "rewards looks like (1478,)\n", + "log_probs looks like (1478,)\n", + "logs prob looks like torch.Size([1478])\n", + "torch.from_numpy(rewards) looks like torch.Size([1478])\n", + "rewards looks like (1279,)\n", + "log_probs looks like (1279,)\n", + "logs prob looks like torch.Size([1279])\n", + "torch.from_numpy(rewards) looks like torch.Size([1279])\n", + "rewards looks like (676,)\n", + "log_probs looks like (676,)\n", + "logs prob looks like torch.Size([676])\n", + "torch.from_numpy(rewards) looks like torch.Size([676])\n", + "rewards looks like (1768,)\n", + "log_probs looks like (1768,)\n", + "logs prob looks like torch.Size([1768])\n", + "torch.from_numpy(rewards) looks like torch.Size([1768])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like 
torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (1252,)\n", + "log_probs looks like (1252,)\n", + "logs prob looks like torch.Size([1252])\n", + "torch.from_numpy(rewards) looks like torch.Size([1252])\n", + "rewards looks like (995,)\n", + "log_probs looks like (995,)\n", + "logs prob looks like torch.Size([995])\n", + "torch.from_numpy(rewards) looks like torch.Size([995])\n", + "rewards looks like (1075,)\n", + "log_probs looks like (1075,)\n", + "logs prob looks like torch.Size([1075])\n", + "torch.from_numpy(rewards) looks like torch.Size([1075])\n", + "rewards looks like (878,)\n", + "log_probs looks like (878,)\n", + "logs prob looks like torch.Size([878])\n", + "torch.from_numpy(rewards) looks like torch.Size([878])\n", + "rewards looks like (1341,)\n", + "log_probs looks like (1341,)\n", + "logs prob looks like torch.Size([1341])\n", + "torch.from_numpy(rewards) looks like torch.Size([1341])\n", + "rewards looks like (1518,)\n", + "log_probs looks like (1518,)\n", + "logs prob looks like torch.Size([1518])\n", + "torch.from_numpy(rewards) looks like torch.Size([1518])\n", + "rewards looks like (1781,)\n", + "log_probs looks like (1781,)\n", + "logs prob looks like torch.Size([1781])\n", + "torch.from_numpy(rewards) looks like torch.Size([1781])\n", + "rewards looks like (1725,)\n", + "log_probs looks like (1725,)\n", + "logs prob looks like torch.Size([1725])\n", + "torch.from_numpy(rewards) looks like torch.Size([1725])\n", + "rewards looks like (1602,)\n", + "log_probs looks like (1602,)\n", + "logs prob looks like torch.Size([1602])\n", + "torch.from_numpy(rewards) looks like torch.Size([1602])\n", + "rewards looks like (846,)\n", + "log_probs looks like (846,)\n", + "logs prob looks like torch.Size([846])\n", + "torch.from_numpy(rewards) looks like torch.Size([846])\n", + "rewards looks like (1211,)\n", + "log_probs looks like (1211,)\n", + "logs prob looks like torch.Size([1211])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1211])\n", + "rewards looks like (3273,)\n", + "log_probs looks like (3273,)\n", + "logs prob looks like torch.Size([3273])\n", + "torch.from_numpy(rewards) looks like torch.Size([3273])\n", + "rewards looks like (744,)\n", + "log_probs looks like (744,)\n", + "logs prob looks like torch.Size([744])\n", + "torch.from_numpy(rewards) looks like torch.Size([744])\n", + "rewards looks like (1751,)\n", + "log_probs looks like (1751,)\n", + "logs prob looks like torch.Size([1751])\n", + "torch.from_numpy(rewards) looks like torch.Size([1751])\n", + "rewards looks like (1244,)\n", + "log_probs looks like (1244,)\n", + "logs prob looks like torch.Size([1244])\n", + "torch.from_numpy(rewards) looks like torch.Size([1244])\n", + "rewards looks like (1313,)\n", + "log_probs looks like (1313,)\n", + "logs prob looks like torch.Size([1313])\n", + "torch.from_numpy(rewards) looks like torch.Size([1313])\n", + "rewards looks like (1993,)\n", + "log_probs looks like (1993,)\n", + "logs prob looks like torch.Size([1993])\n", + "torch.from_numpy(rewards) looks like torch.Size([1993])\n", + "rewards looks like (709,)\n", + "log_probs looks like (709,)\n", + "logs prob looks like torch.Size([709])\n", + "torch.from_numpy(rewards) looks like torch.Size([709])\n", + "rewards looks like (934,)\n", + "log_probs looks like (934,)\n", + "logs prob looks like torch.Size([934])\n", + "torch.from_numpy(rewards) looks like torch.Size([934])\n", + "rewards looks like (1386,)\n", + "log_probs looks like (1386,)\n", + "logs prob looks like torch.Size([1386])\n", + "torch.from_numpy(rewards) looks like torch.Size([1386])\n", + "rewards looks like (635,)\n", + "log_probs looks like (635,)\n", + "logs prob looks like torch.Size([635])\n", + "torch.from_numpy(rewards) looks like torch.Size([635])\n", + "rewards looks like (750,)\n", + "log_probs looks like (750,)\n", + "logs prob looks like torch.Size([750])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([750])\n", + "rewards looks like (1832,)\n", + "log_probs looks like (1832,)\n", + "logs prob looks like torch.Size([1832])\n", + "torch.from_numpy(rewards) looks like torch.Size([1832])\n", + "rewards looks like (1237,)\n", + "log_probs looks like (1237,)\n", + "logs prob looks like torch.Size([1237])\n", + "torch.from_numpy(rewards) looks like torch.Size([1237])\n", + "rewards looks like (1605,)\n", + "log_probs looks like (1605,)\n", + "logs prob looks like torch.Size([1605])\n", + "torch.from_numpy(rewards) looks like torch.Size([1605])\n", + "rewards looks like (718,)\n", + "log_probs looks like (718,)\n", + "logs prob looks like torch.Size([718])\n", + "torch.from_numpy(rewards) looks like torch.Size([718])\n", + "rewards looks like (966,)\n", + "log_probs looks like (966,)\n", + "logs prob looks like torch.Size([966])\n", + "torch.from_numpy(rewards) looks like torch.Size([966])\n", + "rewards looks like (2696,)\n", + "log_probs looks like (2696,)\n", + "logs prob looks like torch.Size([2696])\n", + "torch.from_numpy(rewards) looks like torch.Size([2696])\n", + "rewards looks like (762,)\n", + "log_probs looks like (762,)\n", + "logs prob looks like torch.Size([762])\n", + "torch.from_numpy(rewards) looks like torch.Size([762])\n", + "rewards looks like (1048,)\n", + "log_probs looks like (1048,)\n", + "logs prob looks like torch.Size([1048])\n", + "torch.from_numpy(rewards) looks like torch.Size([1048])\n", + "rewards looks like (1573,)\n", + "log_probs looks like (1573,)\n", + "logs prob looks like torch.Size([1573])\n", + "torch.from_numpy(rewards) looks like torch.Size([1573])\n", + "rewards looks like (2192,)\n", + "log_probs looks like (2192,)\n", + "logs prob looks like torch.Size([2192])\n", + "torch.from_numpy(rewards) looks like torch.Size([2192])\n", + "rewards looks like (599,)\n", + "log_probs looks like (599,)\n", + "logs prob looks like torch.Size([599])\n", + "torch.from_numpy(rewards) looks like torch.Size([599])\n", + "rewards 
looks like (758,)\n", + "log_probs looks like (758,)\n", + "logs prob looks like torch.Size([758])\n", + "torch.from_numpy(rewards) looks like torch.Size([758])\n", + "rewards looks like (1955,)\n", + "log_probs looks like (1955,)\n", + "logs prob looks like torch.Size([1955])\n", + "torch.from_numpy(rewards) looks like torch.Size([1955])\n", + "rewards looks like (1770,)\n", + "log_probs looks like (1770,)\n", + "logs prob looks like torch.Size([1770])\n", + "torch.from_numpy(rewards) looks like torch.Size([1770])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (507,)\n", + "log_probs looks like (507,)\n", + "logs prob looks like torch.Size([507])\n", + "torch.from_numpy(rewards) looks like torch.Size([507])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (1341,)\n", + "log_probs looks like (1341,)\n", + "logs prob looks like torch.Size([1341])\n", + "torch.from_numpy(rewards) looks like torch.Size([1341])\n", + "rewards looks like (1489,)\n", + "log_probs looks like (1489,)\n", + "logs prob looks like torch.Size([1489])\n", + "torch.from_numpy(rewards) looks like torch.Size([1489])\n", + "rewards looks like (3342,)\n", + "log_probs looks like (3342,)\n", + "logs prob looks like torch.Size([3342])\n", + "torch.from_numpy(rewards) looks like torch.Size([3342])\n", + "rewards looks like (1891,)\n", + "log_probs looks like (1891,)\n", + "logs prob looks like torch.Size([1891])\n", + "torch.from_numpy(rewards) looks like torch.Size([1891])\n", + "rewards looks like (1401,)\n", + "log_probs looks like (1401,)\n", + "logs prob looks like torch.Size([1401])\n", + "torch.from_numpy(rewards) looks like torch.Size([1401])\n", + "rewards looks like (2964,)\n", + 
"log_probs looks like (2964,)\n", + "logs prob looks like torch.Size([2964])\n", + "torch.from_numpy(rewards) looks like torch.Size([2964])\n", + "rewards looks like (1404,)\n", + "log_probs looks like (1404,)\n", + "logs prob looks like torch.Size([1404])\n", + "torch.from_numpy(rewards) looks like torch.Size([1404])\n", + "rewards looks like (780,)\n", + "log_probs looks like (780,)\n", + "logs prob looks like torch.Size([780])\n", + "torch.from_numpy(rewards) looks like torch.Size([780])\n", + "rewards looks like (1632,)\n", + "log_probs looks like (1632,)\n", + "logs prob looks like torch.Size([1632])\n", + "torch.from_numpy(rewards) looks like torch.Size([1632])\n", + "rewards looks like (1578,)\n", + "log_probs looks like (1578,)\n", + "logs prob looks like torch.Size([1578])\n", + "torch.from_numpy(rewards) looks like torch.Size([1578])\n", + "rewards looks like (1082,)\n", + "log_probs looks like (1082,)\n", + "logs prob looks like torch.Size([1082])\n", + "torch.from_numpy(rewards) looks like torch.Size([1082])\n", + "rewards looks like (1423,)\n", + "log_probs looks like (1423,)\n", + "logs prob looks like torch.Size([1423])\n", + "torch.from_numpy(rewards) looks like torch.Size([1423])\n", + "rewards looks like (2867,)\n", + "log_probs looks like (2867,)\n", + "logs prob looks like torch.Size([2867])\n", + "torch.from_numpy(rewards) looks like torch.Size([2867])\n", + "rewards looks like (1733,)\n", + "log_probs looks like (1733,)\n", + "logs prob looks like torch.Size([1733])\n", + "torch.from_numpy(rewards) looks like torch.Size([1733])\n", + "rewards looks like (646,)\n", + "log_probs looks like (646,)\n", + "logs prob looks like torch.Size([646])\n", + "torch.from_numpy(rewards) looks like torch.Size([646])\n", + "rewards looks like (1576,)\n", + "log_probs looks like (1576,)\n", + "logs prob looks like torch.Size([1576])\n", + "torch.from_numpy(rewards) looks like torch.Size([1576])\n", + "rewards looks like (1869,)\n", + "log_probs looks like 
(1869,)\n", + "logs prob looks like torch.Size([1869])\n", + "torch.from_numpy(rewards) looks like torch.Size([1869])\n", + "rewards looks like (1862,)\n", + "log_probs looks like (1862,)\n", + "logs prob looks like torch.Size([1862])\n", + "torch.from_numpy(rewards) looks like torch.Size([1862])\n", + "rewards looks like (3182,)\n", + "log_probs looks like (3182,)\n", + "logs prob looks like torch.Size([3182])\n", + "torch.from_numpy(rewards) looks like torch.Size([3182])\n", + "rewards looks like (1746,)\n", + "log_probs looks like (1746,)\n", + "logs prob looks like torch.Size([1746])\n", + "torch.from_numpy(rewards) looks like torch.Size([1746])\n", + "rewards looks like (1855,)\n", + "log_probs looks like (1855,)\n", + "logs prob looks like torch.Size([1855])\n", + "torch.from_numpy(rewards) looks like torch.Size([1855])\n", + "rewards looks like (2710,)\n", + "log_probs looks like (2710,)\n", + "logs prob looks like torch.Size([2710])\n", + "torch.from_numpy(rewards) looks like torch.Size([2710])\n", + "rewards looks like (1707,)\n", + "log_probs looks like (1707,)\n", + "logs prob looks like torch.Size([1707])\n", + "torch.from_numpy(rewards) looks like torch.Size([1707])\n", + "rewards looks like (1723,)\n", + "log_probs looks like (1723,)\n", + "logs prob looks like torch.Size([1723])\n", + "torch.from_numpy(rewards) looks like torch.Size([1723])\n", + "rewards looks like (1590,)\n", + "log_probs looks like (1590,)\n", + "logs prob looks like torch.Size([1590])\n", + "torch.from_numpy(rewards) looks like torch.Size([1590])\n", + "rewards looks like (1432,)\n", + "log_probs looks like (1432,)\n", + "logs prob looks like torch.Size([1432])\n", + "torch.from_numpy(rewards) looks like torch.Size([1432])\n", + "rewards looks like (2742,)\n", + "log_probs looks like (2742,)\n", + "logs prob looks like torch.Size([2742])\n", + "torch.from_numpy(rewards) looks like torch.Size([2742])\n", + "rewards looks like (3007,)\n", + "log_probs looks like (3007,)\n", + "logs 
prob looks like torch.Size([3007])\n", + "torch.from_numpy(rewards) looks like torch.Size([3007])\n", + "rewards looks like (2064,)\n", + "log_probs looks like (2064,)\n", + "logs prob looks like torch.Size([2064])\n", + "torch.from_numpy(rewards) looks like torch.Size([2064])\n", + "rewards looks like (1447,)\n", + "log_probs looks like (1447,)\n", + "logs prob looks like torch.Size([1447])\n", + "torch.from_numpy(rewards) looks like torch.Size([1447])\n", + "rewards looks like (4007,)\n", + "log_probs looks like (4007,)\n", + "logs prob looks like torch.Size([4007])\n", + "torch.from_numpy(rewards) looks like torch.Size([4007])\n", + "rewards looks like (611,)\n", + "log_probs looks like (611,)\n", + "logs prob looks like torch.Size([611])\n", + "torch.from_numpy(rewards) looks like torch.Size([611])\n", + "rewards looks like (1633,)\n", + "log_probs looks like (1633,)\n", + "logs prob looks like torch.Size([1633])\n", + "torch.from_numpy(rewards) looks like torch.Size([1633])\n", + "rewards looks like (3295,)\n", + "log_probs looks like (3295,)\n", + "logs prob looks like torch.Size([3295])\n", + "torch.from_numpy(rewards) looks like torch.Size([3295])\n", + "rewards looks like (975,)\n", + "log_probs looks like (975,)\n", + "logs prob looks like torch.Size([975])\n", + "torch.from_numpy(rewards) looks like torch.Size([975])\n", + "rewards looks like (1991,)\n", + "log_probs looks like (1991,)\n", + "logs prob looks like torch.Size([1991])\n", + "torch.from_numpy(rewards) looks like torch.Size([1991])\n", + "rewards looks like (2409,)\n", + "log_probs looks like (2409,)\n", + "logs prob looks like torch.Size([2409])\n", + "torch.from_numpy(rewards) looks like torch.Size([2409])\n", + "rewards looks like (1587,)\n", + "log_probs looks like (1587,)\n", + "logs prob looks like torch.Size([1587])\n", + "torch.from_numpy(rewards) looks like torch.Size([1587])\n", + "rewards looks like (1334,)\n", + "log_probs looks like (1334,)\n", + "logs prob looks like 
torch.Size([1334])\n", + "torch.from_numpy(rewards) looks like torch.Size([1334])\n", + "rewards looks like (1070,)\n", + "log_probs looks like (1070,)\n", + "logs prob looks like torch.Size([1070])\n", + "torch.from_numpy(rewards) looks like torch.Size([1070])\n", + "rewards looks like (1082,)\n", + "log_probs looks like (1082,)\n", + "logs prob looks like torch.Size([1082])\n", + "torch.from_numpy(rewards) looks like torch.Size([1082])\n", + "rewards looks like (1084,)\n", + "log_probs looks like (1084,)\n", + "logs prob looks like torch.Size([1084])\n", + "torch.from_numpy(rewards) looks like torch.Size([1084])\n", + "rewards looks like (1192,)\n", + "log_probs looks like (1192,)\n", + "logs prob looks like torch.Size([1192])\n", + "torch.from_numpy(rewards) looks like torch.Size([1192])\n", + "rewards looks like (1287,)\n", + "log_probs looks like (1287,)\n", + "logs prob looks like torch.Size([1287])\n", + "torch.from_numpy(rewards) looks like torch.Size([1287])\n", + "rewards looks like (1718,)\n", + "log_probs looks like (1718,)\n", + "logs prob looks like torch.Size([1718])\n", + "torch.from_numpy(rewards) looks like torch.Size([1718])\n", + "rewards looks like (1859,)\n", + "log_probs looks like (1859,)\n", + "logs prob looks like torch.Size([1859])\n", + "torch.from_numpy(rewards) looks like torch.Size([1859])\n", + "rewards looks like (1215,)\n", + "log_probs looks like (1215,)\n", + "logs prob looks like torch.Size([1215])\n", + "torch.from_numpy(rewards) looks like torch.Size([1215])\n", + "rewards looks like (1181,)\n", + "log_probs looks like (1181,)\n", + "logs prob looks like torch.Size([1181])\n", + "torch.from_numpy(rewards) looks like torch.Size([1181])\n", + "rewards looks like (1378,)\n", + "log_probs looks like (1378,)\n", + "logs prob looks like torch.Size([1378])\n", + "torch.from_numpy(rewards) looks like torch.Size([1378])\n", + "rewards looks like (1851,)\n", + "log_probs looks like (1851,)\n", + "logs prob looks like 
torch.Size([1851])\n", + "torch.from_numpy(rewards) looks like torch.Size([1851])\n", + "rewards looks like (2218,)\n", + "log_probs looks like (2218,)\n", + "logs prob looks like torch.Size([2218])\n", + "torch.from_numpy(rewards) looks like torch.Size([2218])\n", + "rewards looks like (2502,)\n", + "log_probs looks like (2502,)\n", + "logs prob looks like torch.Size([2502])\n", + "torch.from_numpy(rewards) looks like torch.Size([2502])\n", + "rewards looks like (1642,)\n", + "log_probs looks like (1642,)\n", + "logs prob looks like torch.Size([1642])\n", + "torch.from_numpy(rewards) looks like torch.Size([1642])\n", + "rewards looks like (1892,)\n", + "log_probs looks like (1892,)\n", + "logs prob looks like torch.Size([1892])\n", + "torch.from_numpy(rewards) looks like torch.Size([1892])\n", + "rewards looks like (2003,)\n", + "log_probs looks like (2003,)\n", + "logs prob looks like torch.Size([2003])\n", + "torch.from_numpy(rewards) looks like torch.Size([2003])\n", + "rewards looks like (3407,)\n", + "log_probs looks like (3407,)\n", + "logs prob looks like torch.Size([3407])\n", + "torch.from_numpy(rewards) looks like torch.Size([3407])\n", + "rewards looks like (3425,)\n", + "log_probs looks like (3425,)\n", + "logs prob looks like torch.Size([3425])\n", + "torch.from_numpy(rewards) looks like torch.Size([3425])\n", + "rewards looks like (1840,)\n", + "log_probs looks like (1840,)\n", + "logs prob looks like torch.Size([1840])\n", + "torch.from_numpy(rewards) looks like torch.Size([1840])\n", + "rewards looks like (1529,)\n", + "log_probs looks like (1529,)\n", + "logs prob looks like torch.Size([1529])\n", + "torch.from_numpy(rewards) looks like torch.Size([1529])\n", + "rewards looks like (1407,)\n", + "log_probs looks like (1407,)\n", + "logs prob looks like torch.Size([1407])\n", + "torch.from_numpy(rewards) looks like torch.Size([1407])\n", + "rewards looks like (2541,)\n", + "log_probs looks like (2541,)\n", + "logs prob looks like 
torch.Size([2541])\n", + "torch.from_numpy(rewards) looks like torch.Size([2541])\n", + "rewards looks like (1194,)\n", + "log_probs looks like (1194,)\n", + "logs prob looks like torch.Size([1194])\n", + "torch.from_numpy(rewards) looks like torch.Size([1194])\n", + "rewards looks like (1431,)\n", + "log_probs looks like (1431,)\n", + "logs prob looks like torch.Size([1431])\n", + "torch.from_numpy(rewards) looks like torch.Size([1431])\n", + "rewards looks like (3340,)\n", + "log_probs looks like (3340,)\n", + "logs prob looks like torch.Size([3340])\n", + "torch.from_numpy(rewards) looks like torch.Size([3340])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (1821,)\n", + "log_probs looks like (1821,)\n", + "logs prob looks like torch.Size([1821])\n", + "torch.from_numpy(rewards) looks like torch.Size([1821])\n", + "rewards looks like (1906,)\n", + "log_probs looks like (1906,)\n", + "logs prob looks like torch.Size([1906])\n", + "torch.from_numpy(rewards) looks like torch.Size([1906])\n", + "rewards looks like (2688,)\n", + "log_probs looks like (2688,)\n", + "logs prob looks like torch.Size([2688])\n", + "torch.from_numpy(rewards) looks like torch.Size([2688])\n", + "rewards looks like (1169,)\n", + "log_probs looks like (1169,)\n", + "logs prob looks like torch.Size([1169])\n", + "torch.from_numpy(rewards) looks like torch.Size([1169])\n", + "rewards looks like (1444,)\n", + "log_probs looks like (1444,)\n", + "logs prob looks like torch.Size([1444])\n", + "torch.from_numpy(rewards) looks like torch.Size([1444])\n", + "rewards looks like (1376,)\n", + "log_probs looks like (1376,)\n", + "logs prob looks like torch.Size([1376])\n", + "torch.from_numpy(rewards) looks like torch.Size([1376])\n", + "rewards looks like (1395,)\n", + "log_probs looks like (1395,)\n", + "logs prob looks like torch.Size([1395])\n", 
+ "torch.from_numpy(rewards) looks like torch.Size([1395])\n", + "rewards looks like (899,)\n", + "log_probs looks like (899,)\n", + "logs prob looks like torch.Size([899])\n", + "torch.from_numpy(rewards) looks like torch.Size([899])\n", + "rewards looks like (2152,)\n", + "log_probs looks like (2152,)\n", + "logs prob looks like torch.Size([2152])\n", + "torch.from_numpy(rewards) looks like torch.Size([2152])\n", + "rewards looks like (2294,)\n", + "log_probs looks like (2294,)\n", + "logs prob looks like torch.Size([2294])\n", + "torch.from_numpy(rewards) looks like torch.Size([2294])\n", + "rewards looks like (881,)\n", + "log_probs looks like (881,)\n", + "logs prob looks like torch.Size([881])\n", + "torch.from_numpy(rewards) looks like torch.Size([881])\n", + "rewards looks like (1050,)\n", + "log_probs looks like (1050,)\n", + "logs prob looks like torch.Size([1050])\n", + "torch.from_numpy(rewards) looks like torch.Size([1050])\n", + "rewards looks like (1294,)\n", + "log_probs looks like (1294,)\n", + "logs prob looks like torch.Size([1294])\n", + "torch.from_numpy(rewards) looks like torch.Size([1294])\n", + "rewards looks like (1419,)\n", + "log_probs looks like (1419,)\n", + "logs prob looks like torch.Size([1419])\n", + "torch.from_numpy(rewards) looks like torch.Size([1419])\n", + "rewards looks like (1433,)\n", + "log_probs looks like (1433,)\n", + "logs prob looks like torch.Size([1433])\n", + "torch.from_numpy(rewards) looks like torch.Size([1433])\n", + "rewards looks like (2196,)\n", + "log_probs looks like (2196,)\n", + "logs prob looks like torch.Size([2196])\n", + "torch.from_numpy(rewards) looks like torch.Size([2196])\n", + "rewards looks like (1811,)\n", + "log_probs looks like (1811,)\n", + "logs prob looks like torch.Size([1811])\n", + "torch.from_numpy(rewards) looks like torch.Size([1811])\n", + "rewards looks like (1418,)\n", + "log_probs looks like (1418,)\n", + "logs prob looks like torch.Size([1418])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1418])\n", + "rewards looks like (1536,)\n", + "log_probs looks like (1536,)\n", + "logs prob looks like torch.Size([1536])\n", + "torch.from_numpy(rewards) looks like torch.Size([1536])\n", + "rewards looks like (1353,)\n", + "log_probs looks like (1353,)\n", + "logs prob looks like torch.Size([1353])\n", + "torch.from_numpy(rewards) looks like torch.Size([1353])\n", + "rewards looks like (1260,)\n", + "log_probs looks like (1260,)\n", + "logs prob looks like torch.Size([1260])\n", + "torch.from_numpy(rewards) looks like torch.Size([1260])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (2095,)\n", + "log_probs looks like (2095,)\n", + "logs prob looks like torch.Size([2095])\n", + "torch.from_numpy(rewards) looks like torch.Size([2095])\n", + "rewards looks like (1695,)\n", + "log_probs looks like (1695,)\n", + "logs prob looks like torch.Size([1695])\n", + "torch.from_numpy(rewards) looks like torch.Size([1695])\n", + "rewards looks like (2109,)\n", + "log_probs looks like (2109,)\n", + "logs prob looks like torch.Size([2109])\n", + "torch.from_numpy(rewards) looks like torch.Size([2109])\n", + "rewards looks like (967,)\n", + "log_probs looks like (967,)\n", + "logs prob looks like torch.Size([967])\n", + "torch.from_numpy(rewards) looks like torch.Size([967])\n", + "rewards looks like (1231,)\n", + "log_probs looks like (1231,)\n", + "logs prob looks like torch.Size([1231])\n", + "torch.from_numpy(rewards) looks like torch.Size([1231])\n", + "rewards looks like (1355,)\n", + "log_probs looks like (1355,)\n", + "logs prob looks like torch.Size([1355])\n", + "torch.from_numpy(rewards) looks like torch.Size([1355])\n", + "rewards looks like (1351,)\n", + "log_probs looks like (1351,)\n", + "logs prob looks like torch.Size([1351])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1351])\n", + "rewards looks like (1674,)\n", + "log_probs looks like (1674,)\n", + "logs prob looks like torch.Size([1674])\n", + "torch.from_numpy(rewards) looks like torch.Size([1674])\n", + "rewards looks like (2394,)\n", + "log_probs looks like (2394,)\n", + "logs prob looks like torch.Size([2394])\n", + "torch.from_numpy(rewards) looks like torch.Size([2394])\n", + "rewards looks like (2296,)\n", + "log_probs looks like (2296,)\n", + "logs prob looks like torch.Size([2296])\n", + "torch.from_numpy(rewards) looks like torch.Size([2296])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (2389,)\n", + "log_probs looks like (2389,)\n", + "logs prob looks like torch.Size([2389])\n", + "torch.from_numpy(rewards) looks like torch.Size([2389])\n", + "rewards looks like (1798,)\n", + "log_probs looks like (1798,)\n", + "logs prob looks like torch.Size([1798])\n", + "torch.from_numpy(rewards) looks like torch.Size([1798])\n", + "rewards looks like (1232,)\n", + "log_probs looks like (1232,)\n", + "logs prob looks like torch.Size([1232])\n", + "torch.from_numpy(rewards) looks like torch.Size([1232])\n", + "rewards looks like (1173,)\n", + "log_probs looks like (1173,)\n", + "logs prob looks like torch.Size([1173])\n", + "torch.from_numpy(rewards) looks like torch.Size([1173])\n", + "rewards looks like (1165,)\n", + "log_probs looks like (1165,)\n", + "logs prob looks like torch.Size([1165])\n", + "torch.from_numpy(rewards) looks like torch.Size([1165])\n", + "rewards looks like (1602,)\n", + "log_probs looks like (1602,)\n", + "logs prob looks like torch.Size([1602])\n", + "torch.from_numpy(rewards) looks like torch.Size([1602])\n", + "rewards looks like (1164,)\n", + "log_probs looks like (1164,)\n", + "logs prob looks like torch.Size([1164])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1164])\n", + "rewards looks like (2235,)\n", + "log_probs looks like (2235,)\n", + "logs prob looks like torch.Size([2235])\n", + "torch.from_numpy(rewards) looks like torch.Size([2235])\n", + "rewards looks like (1038,)\n", + "log_probs looks like (1038,)\n", + "logs prob looks like torch.Size([1038])\n", + "torch.from_numpy(rewards) looks like torch.Size([1038])\n", + "rewards looks like (1698,)\n", + "log_probs looks like (1698,)\n", + "logs prob looks like torch.Size([1698])\n", + "torch.from_numpy(rewards) looks like torch.Size([1698])\n", + "rewards looks like (1436,)\n", + "log_probs looks like (1436,)\n", + "logs prob looks like torch.Size([1436])\n", + "torch.from_numpy(rewards) looks like torch.Size([1436])\n", + "rewards looks like (1223,)\n", + "log_probs looks like (1223,)\n", + "logs prob looks like torch.Size([1223])\n", + "torch.from_numpy(rewards) looks like torch.Size([1223])\n", + "rewards looks like (2006,)\n", + "log_probs looks like (2006,)\n", + "logs prob looks like torch.Size([2006])\n", + "torch.from_numpy(rewards) looks like torch.Size([2006])\n", + "rewards looks like (1162,)\n", + "log_probs looks like (1162,)\n", + "logs prob looks like torch.Size([1162])\n", + "torch.from_numpy(rewards) looks like torch.Size([1162])\n", + "rewards looks like (2239,)\n", + "log_probs looks like (2239,)\n", + "logs prob looks like torch.Size([2239])\n", + "torch.from_numpy(rewards) looks like torch.Size([2239])\n", + "rewards looks like (1104,)\n", + "log_probs looks like (1104,)\n", + "logs prob looks like torch.Size([1104])\n", + "torch.from_numpy(rewards) looks like torch.Size([1104])\n", + "rewards looks like (1389,)\n", + "log_probs looks like (1389,)\n", + "logs prob looks like torch.Size([1389])\n", + "torch.from_numpy(rewards) looks like torch.Size([1389])\n", + "rewards looks like (1419,)\n", + "log_probs looks like (1419,)\n", + "logs prob looks like torch.Size([1419])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1419])\n", + "rewards looks like (1526,)\n", + "log_probs looks like (1526,)\n", + "logs prob looks like torch.Size([1526])\n", + "torch.from_numpy(rewards) looks like torch.Size([1526])\n", + "rewards looks like (1618,)\n", + "log_probs looks like (1618,)\n", + "logs prob looks like torch.Size([1618])\n", + "torch.from_numpy(rewards) looks like torch.Size([1618])\n", + "rewards looks like (2276,)\n", + "log_probs looks like (2276,)\n", + "logs prob looks like torch.Size([2276])\n", + "torch.from_numpy(rewards) looks like torch.Size([2276])\n", + "rewards looks like (2973,)\n", + "log_probs looks like (2973,)\n", + "logs prob looks like torch.Size([2973])\n", + "torch.from_numpy(rewards) looks like torch.Size([2973])\n", + "rewards looks like (1418,)\n", + "log_probs looks like (1418,)\n", + "logs prob looks like torch.Size([1418])\n", + "torch.from_numpy(rewards) looks like torch.Size([1418])\n", + "rewards looks like (1273,)\n", + "log_probs looks like (1273,)\n", + "logs prob looks like torch.Size([1273])\n", + "torch.from_numpy(rewards) looks like torch.Size([1273])\n", + "rewards looks like (2355,)\n", + "log_probs looks like (2355,)\n", + "logs prob looks like torch.Size([2355])\n", + "torch.from_numpy(rewards) looks like torch.Size([2355])\n", + "rewards looks like (1308,)\n", + "log_probs looks like (1308,)\n", + "logs prob looks like torch.Size([1308])\n", + "torch.from_numpy(rewards) looks like torch.Size([1308])\n", + "rewards looks like (1403,)\n", + "log_probs looks like (1403,)\n", + "logs prob looks like torch.Size([1403])\n", + "torch.from_numpy(rewards) looks like torch.Size([1403])\n", + "rewards looks like (1794,)\n", + "log_probs looks like (1794,)\n", + "logs prob looks like torch.Size([1794])\n", + "torch.from_numpy(rewards) looks like torch.Size([1794])\n", + "rewards looks like (1101,)\n", + "log_probs looks like (1101,)\n", + "logs prob looks like torch.Size([1101])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1101])\n", + "rewards looks like (1165,)\n", + "log_probs looks like (1165,)\n", + "logs prob looks like torch.Size([1165])\n", + "torch.from_numpy(rewards) looks like torch.Size([1165])\n", + "rewards looks like (1162,)\n", + "log_probs looks like (1162,)\n", + "logs prob looks like torch.Size([1162])\n", + "torch.from_numpy(rewards) looks like torch.Size([1162])\n", + "rewards looks like (1317,)\n", + "log_probs looks like (1317,)\n", + "logs prob looks like torch.Size([1317])\n", + "torch.from_numpy(rewards) looks like torch.Size([1317])\n", + "rewards looks like (993,)\n", + "log_probs looks like (993,)\n", + "logs prob looks like torch.Size([993])\n", + "torch.from_numpy(rewards) looks like torch.Size([993])\n", + "rewards looks like (2078,)\n", + "log_probs looks like (2078,)\n", + "logs prob looks like torch.Size([2078])\n", + "torch.from_numpy(rewards) looks like torch.Size([2078])\n", + "rewards looks like (1419,)\n", + "log_probs looks like (1419,)\n", + "logs prob looks like torch.Size([1419])\n", + "torch.from_numpy(rewards) looks like torch.Size([1419])\n", + "rewards looks like (1354,)\n", + "log_probs looks like (1354,)\n", + "logs prob looks like torch.Size([1354])\n", + "torch.from_numpy(rewards) looks like torch.Size([1354])\n", + "rewards looks like (1216,)\n", + "log_probs looks like (1216,)\n", + "logs prob looks like torch.Size([1216])\n", + "torch.from_numpy(rewards) looks like torch.Size([1216])\n", + "rewards looks like (1661,)\n", + "log_probs looks like (1661,)\n", + "logs prob looks like torch.Size([1661])\n", + "torch.from_numpy(rewards) looks like torch.Size([1661])\n", + "rewards looks like (2095,)\n", + "log_probs looks like (2095,)\n", + "logs prob looks like torch.Size([2095])\n", + "torch.from_numpy(rewards) looks like torch.Size([2095])\n", + "rewards looks like (2455,)\n", + "log_probs looks like (2455,)\n", + "logs prob looks like torch.Size([2455])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2455])\n", + "rewards looks like (2383,)\n", + "log_probs looks like (2383,)\n", + "logs prob looks like torch.Size([2383])\n", + "torch.from_numpy(rewards) looks like torch.Size([2383])\n", + "rewards looks like (2222,)\n", + "log_probs looks like (2222,)\n", + "logs prob looks like torch.Size([2222])\n", + "torch.from_numpy(rewards) looks like torch.Size([2222])\n", + "rewards looks like (2269,)\n", + "log_probs looks like (2269,)\n", + "logs prob looks like torch.Size([2269])\n", + "torch.from_numpy(rewards) looks like torch.Size([2269])\n", + "rewards looks like (2995,)\n", + "log_probs looks like (2995,)\n", + "logs prob looks like torch.Size([2995])\n", + "torch.from_numpy(rewards) looks like torch.Size([2995])\n", + "rewards looks like (1474,)\n", + "log_probs looks like (1474,)\n", + "logs prob looks like torch.Size([1474])\n", + "torch.from_numpy(rewards) looks like torch.Size([1474])\n", + "rewards looks like (2666,)\n", + "log_probs looks like (2666,)\n", + "logs prob looks like torch.Size([2666])\n", + "torch.from_numpy(rewards) looks like torch.Size([2666])\n", + "rewards looks like (1386,)\n", + "log_probs looks like (1386,)\n", + "logs prob looks like torch.Size([1386])\n", + "torch.from_numpy(rewards) looks like torch.Size([1386])\n", + "rewards looks like (2039,)\n", + "log_probs looks like (2039,)\n", + "logs prob looks like torch.Size([2039])\n", + "torch.from_numpy(rewards) looks like torch.Size([2039])\n", + "rewards looks like (2172,)\n", + "log_probs looks like (2172,)\n", + "logs prob looks like torch.Size([2172])\n", + "torch.from_numpy(rewards) looks like torch.Size([2172])\n", + "rewards looks like (2070,)\n", + "log_probs looks like (2070,)\n", + "logs prob looks like torch.Size([2070])\n", + "torch.from_numpy(rewards) looks like torch.Size([2070])\n", + "rewards looks like (2534,)\n", + "log_probs looks like (2534,)\n", + "logs prob looks like torch.Size([2534])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2534])\n", + "rewards looks like (1660,)\n", + "log_probs looks like (1660,)\n", + "logs prob looks like torch.Size([1660])\n", + "torch.from_numpy(rewards) looks like torch.Size([1660])\n", + "rewards looks like (1406,)\n", + "log_probs looks like (1406,)\n", + "logs prob looks like torch.Size([1406])\n", + "torch.from_numpy(rewards) looks like torch.Size([1406])\n", + "rewards looks like (1472,)\n", + "log_probs looks like (1472,)\n", + "logs prob looks like torch.Size([1472])\n", + "torch.from_numpy(rewards) looks like torch.Size([1472])\n", + "rewards looks like (2711,)\n", + "log_probs looks like (2711,)\n", + "logs prob looks like torch.Size([2711])\n", + "torch.from_numpy(rewards) looks like torch.Size([2711])\n", + "rewards looks like (1529,)\n", + "log_probs looks like (1529,)\n", + "logs prob looks like torch.Size([1529])\n", + "torch.from_numpy(rewards) looks like torch.Size([1529])\n", + "rewards looks like (1867,)\n", + "log_probs looks like (1867,)\n", + "logs prob looks like torch.Size([1867])\n", + "torch.from_numpy(rewards) looks like torch.Size([1867])\n", + "rewards looks like (1218,)\n", + "log_probs looks like (1218,)\n", + "logs prob looks like torch.Size([1218])\n", + "torch.from_numpy(rewards) looks like torch.Size([1218])\n", + "rewards looks like (1345,)\n", + "log_probs looks like (1345,)\n", + "logs prob looks like torch.Size([1345])\n", + "torch.from_numpy(rewards) looks like torch.Size([1345])\n", + "rewards looks like (1188,)\n", + "log_probs looks like (1188,)\n", + "logs prob looks like torch.Size([1188])\n", + "torch.from_numpy(rewards) looks like torch.Size([1188])\n", + "rewards looks like (1945,)\n", + "log_probs looks like (1945,)\n", + "logs prob looks like torch.Size([1945])\n", + "torch.from_numpy(rewards) looks like torch.Size([1945])\n", + "rewards looks like (987,)\n", + "log_probs looks like (987,)\n", + "logs prob looks like torch.Size([987])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([987])\n", + "rewards looks like (2017,)\n", + "log_probs looks like (2017,)\n", + "logs prob looks like torch.Size([2017])\n", + "torch.from_numpy(rewards) looks like torch.Size([2017])\n", + "rewards looks like (2001,)\n", + "log_probs looks like (2001,)\n", + "logs prob looks like torch.Size([2001])\n", + "torch.from_numpy(rewards) looks like torch.Size([2001])\n", + "rewards looks like (1335,)\n", + "log_probs looks like (1335,)\n", + "logs prob looks like torch.Size([1335])\n", + "torch.from_numpy(rewards) looks like torch.Size([1335])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (2834,)\n", + "log_probs looks like (2834,)\n", + "logs prob looks like torch.Size([2834])\n", + "torch.from_numpy(rewards) looks like torch.Size([2834])\n", + "rewards looks like (1391,)\n", + "log_probs looks like (1391,)\n", + "logs prob looks like torch.Size([1391])\n", + "torch.from_numpy(rewards) looks like torch.Size([1391])\n", + "rewards looks like (1852,)\n", + "log_probs looks like (1852,)\n", + "logs prob looks like torch.Size([1852])\n", + "torch.from_numpy(rewards) looks like torch.Size([1852])\n", + "rewards looks like (1256,)\n", + "log_probs looks like (1256,)\n", + "logs prob looks like torch.Size([1256])\n", + "torch.from_numpy(rewards) looks like torch.Size([1256])\n", + "rewards looks like (1184,)\n", + "log_probs looks like (1184,)\n", + "logs prob looks like torch.Size([1184])\n", + "torch.from_numpy(rewards) looks like torch.Size([1184])\n", + "rewards looks like (1939,)\n", + "log_probs looks like (1939,)\n", + "logs prob looks like torch.Size([1939])\n", + "torch.from_numpy(rewards) looks like torch.Size([1939])\n", + "rewards looks like (1274,)\n", + "log_probs looks like (1274,)\n", + "logs prob looks like torch.Size([1274])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1274])\n", + "rewards looks like (1367,)\n", + "log_probs looks like (1367,)\n", + "logs prob looks like torch.Size([1367])\n", + "torch.from_numpy(rewards) looks like torch.Size([1367])\n", + "rewards looks like (1284,)\n", + "log_probs looks like (1284,)\n", + "logs prob looks like torch.Size([1284])\n", + "torch.from_numpy(rewards) looks like torch.Size([1284])\n", + "rewards looks like (1127,)\n", + "log_probs looks like (1127,)\n", + "logs prob looks like torch.Size([1127])\n", + "torch.from_numpy(rewards) looks like torch.Size([1127])\n", + "rewards looks like (1298,)\n", + "log_probs looks like (1298,)\n", + "logs prob looks like torch.Size([1298])\n", + "torch.from_numpy(rewards) looks like torch.Size([1298])\n", + "rewards looks like (1638,)\n", + "log_probs looks like (1638,)\n", + "logs prob looks like torch.Size([1638])\n", + "torch.from_numpy(rewards) looks like torch.Size([1638])\n", + "rewards looks like (1144,)\n", + "log_probs looks like (1144,)\n", + "logs prob looks like torch.Size([1144])\n", + "torch.from_numpy(rewards) looks like torch.Size([1144])\n", + "rewards looks like (1370,)\n", + "log_probs looks like (1370,)\n", + "logs prob looks like torch.Size([1370])\n", + "torch.from_numpy(rewards) looks like torch.Size([1370])\n", + "rewards looks like (1835,)\n", + "log_probs looks like (1835,)\n", + "logs prob looks like torch.Size([1835])\n", + "torch.from_numpy(rewards) looks like torch.Size([1835])\n", + "rewards looks like (2149,)\n", + "log_probs looks like (2149,)\n", + "logs prob looks like torch.Size([2149])\n", + "torch.from_numpy(rewards) looks like torch.Size([2149])\n", + "rewards looks like (1033,)\n", + "log_probs looks like (1033,)\n", + "logs prob looks like torch.Size([1033])\n", + "torch.from_numpy(rewards) looks like torch.Size([1033])\n", + "rewards looks like (989,)\n", + "log_probs looks like (989,)\n", + "logs prob looks like torch.Size([989])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([989])\n", + "rewards looks like (1900,)\n", + "log_probs looks like (1900,)\n", + "logs prob looks like torch.Size([1900])\n", + "torch.from_numpy(rewards) looks like torch.Size([1900])\n", + "rewards looks like (1706,)\n", + "log_probs looks like (1706,)\n", + "logs prob looks like torch.Size([1706])\n", + "torch.from_numpy(rewards) looks like torch.Size([1706])\n", + "rewards looks like (1235,)\n", + "log_probs looks like (1235,)\n", + "logs prob looks like torch.Size([1235])\n", + "torch.from_numpy(rewards) looks like torch.Size([1235])\n", + "rewards looks like (2693,)\n", + "log_probs looks like (2693,)\n", + "logs prob looks like torch.Size([2693])\n", + "torch.from_numpy(rewards) looks like torch.Size([2693])\n", + "rewards looks like (1021,)\n", + "log_probs looks like (1021,)\n", + "logs prob looks like torch.Size([1021])\n", + "torch.from_numpy(rewards) looks like torch.Size([1021])\n", + "rewards looks like (1126,)\n", + "log_probs looks like (1126,)\n", + "logs prob looks like torch.Size([1126])\n", + "torch.from_numpy(rewards) looks like torch.Size([1126])\n", + "rewards looks like (1334,)\n", + "log_probs looks like (1334,)\n", + "logs prob looks like torch.Size([1334])\n", + "torch.from_numpy(rewards) looks like torch.Size([1334])\n", + "rewards looks like (1337,)\n", + "log_probs looks like (1337,)\n", + "logs prob looks like torch.Size([1337])\n", + "torch.from_numpy(rewards) looks like torch.Size([1337])\n", + "rewards looks like (1502,)\n", + "log_probs looks like (1502,)\n", + "logs prob looks like torch.Size([1502])\n", + "torch.from_numpy(rewards) looks like torch.Size([1502])\n", + "rewards looks like (2059,)\n", + "log_probs looks like (2059,)\n", + "logs prob looks like torch.Size([2059])\n", + "torch.from_numpy(rewards) looks like torch.Size([2059])\n", + "rewards looks like (2057,)\n", + "log_probs looks like (2057,)\n", + "logs prob looks like torch.Size([2057])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2057])\n", + "rewards looks like (1300,)\n", + "log_probs looks like (1300,)\n", + "logs prob looks like torch.Size([1300])\n", + "torch.from_numpy(rewards) looks like torch.Size([1300])\n", + "rewards looks like (3078,)\n", + "log_probs looks like (3078,)\n", + "logs prob looks like torch.Size([3078])\n", + "torch.from_numpy(rewards) looks like torch.Size([3078])\n", + "rewards looks like (1724,)\n", + "log_probs looks like (1724,)\n", + "logs prob looks like torch.Size([1724])\n", + "torch.from_numpy(rewards) looks like torch.Size([1724])\n", + "rewards looks like (1468,)\n", + "log_probs looks like (1468,)\n", + "logs prob looks like torch.Size([1468])\n", + "torch.from_numpy(rewards) looks like torch.Size([1468])\n", + "rewards looks like (2674,)\n", + "log_probs looks like (2674,)\n", + "logs prob looks like torch.Size([2674])\n", + "torch.from_numpy(rewards) looks like torch.Size([2674])\n", + "rewards looks like (1376,)\n", + "log_probs looks like (1376,)\n", + "logs prob looks like torch.Size([1376])\n", + "torch.from_numpy(rewards) looks like torch.Size([1376])\n", + "rewards looks like (1564,)\n", + "log_probs looks like (1564,)\n", + "logs prob looks like torch.Size([1564])\n", + "torch.from_numpy(rewards) looks like torch.Size([1564])\n", + "rewards looks like (1452,)\n", + "log_probs looks like (1452,)\n", + "logs prob looks like torch.Size([1452])\n", + "torch.from_numpy(rewards) looks like torch.Size([1452])\n", + "rewards looks like (1205,)\n", + "log_probs looks like (1205,)\n", + "logs prob looks like torch.Size([1205])\n", + "torch.from_numpy(rewards) looks like torch.Size([1205])\n", + "rewards looks like (1520,)\n", + "log_probs looks like (1520,)\n", + "logs prob looks like torch.Size([1520])\n", + "torch.from_numpy(rewards) looks like torch.Size([1520])\n", + "rewards looks like (1099,)\n", + "log_probs looks like (1099,)\n", + "logs prob looks like torch.Size([1099])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1099])\n", + "rewards looks like (1506,)\n", + "log_probs looks like (1506,)\n", + "logs prob looks like torch.Size([1506])\n", + "torch.from_numpy(rewards) looks like torch.Size([1506])\n", + "rewards looks like (1175,)\n", + "log_probs looks like (1175,)\n", + "logs prob looks like torch.Size([1175])\n", + "torch.from_numpy(rewards) looks like torch.Size([1175])\n", + "rewards looks like (1251,)\n", + "log_probs looks like (1251,)\n", + "logs prob looks like torch.Size([1251])\n", + "torch.from_numpy(rewards) looks like torch.Size([1251])\n", + "rewards looks like (1318,)\n", + "log_probs looks like (1318,)\n", + "logs prob looks like torch.Size([1318])\n", + "torch.from_numpy(rewards) looks like torch.Size([1318])\n", + "rewards looks like (1446,)\n", + "log_probs looks like (1446,)\n", + "logs prob looks like torch.Size([1446])\n", + "torch.from_numpy(rewards) looks like torch.Size([1446])\n", + "rewards looks like (1220,)\n", + "log_probs looks like (1220,)\n", + "logs prob looks like torch.Size([1220])\n", + "torch.from_numpy(rewards) looks like torch.Size([1220])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (1186,)\n", + "log_probs looks like (1186,)\n", + "logs prob looks like torch.Size([1186])\n", + "torch.from_numpy(rewards) looks like torch.Size([1186])\n", + "rewards looks like (1443,)\n", + "log_probs looks like (1443,)\n", + "logs prob looks like torch.Size([1443])\n", + "torch.from_numpy(rewards) looks like torch.Size([1443])\n", + "rewards looks like (1212,)\n", + "log_probs looks like (1212,)\n", + "logs prob looks like torch.Size([1212])\n", + "torch.from_numpy(rewards) looks like torch.Size([1212])\n", + "rewards looks like (1346,)\n", + "log_probs looks like (1346,)\n", + "logs prob looks like torch.Size([1346])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1346])\n", + "rewards looks like (2124,)\n", + "log_probs looks like (2124,)\n", + "logs prob looks like torch.Size([2124])\n", + "torch.from_numpy(rewards) looks like torch.Size([2124])\n", + "rewards looks like (1461,)\n", + "log_probs looks like (1461,)\n", + "logs prob looks like torch.Size([1461])\n", + "torch.from_numpy(rewards) looks like torch.Size([1461])\n", + "rewards looks like (1425,)\n", + "log_probs looks like (1425,)\n", + "logs prob looks like torch.Size([1425])\n", + "torch.from_numpy(rewards) looks like torch.Size([1425])\n", + "rewards looks like (1457,)\n", + "log_probs looks like (1457,)\n", + "logs prob looks like torch.Size([1457])\n", + "torch.from_numpy(rewards) looks like torch.Size([1457])\n", + "rewards looks like (1223,)\n", + "log_probs looks like (1223,)\n", + "logs prob looks like torch.Size([1223])\n", + "torch.from_numpy(rewards) looks like torch.Size([1223])\n", + "rewards looks like (1310,)\n", + "log_probs looks like (1310,)\n", + "logs prob looks like torch.Size([1310])\n", + "torch.from_numpy(rewards) looks like torch.Size([1310])\n", + "rewards looks like (2446,)\n", + "log_probs looks like (2446,)\n", + "logs prob looks like torch.Size([2446])\n", + "torch.from_numpy(rewards) looks like torch.Size([2446])\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vNb_tuFYhKVK" + }, + "source": [ + "### 訓練結果\n", + "\n", + "訓練過程中,我們持續記下了 `avg_total_reward`,這個數值代表的是:每次更新 policy network 前,我們讓 agent 玩數個回合(episodes),而這些回合的平均 total rewards 為何。\n", + "理論上,若是 agent 一直在進步,則所得到的 `avg_total_reward` 也會持續上升,直至 250 上下。\n", + "若將其畫出來則結果如下:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "wZYOI8H10SHN", + "outputId": "80307382-3743-4f70-e08a-66c5e92451da" + }, + "source": [ + "end = time.time()\n", + "plt.plot(avg_total_rewards)\n", + "plt.title(\"Total 
Rewards\")\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEICAYAAAC3Y/QeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOx9d5gcxZn++3XPzCbtSqucAyAJRAYhkjFgkgAbsLE5MLYB+36cbbizjW0MxhljY5992DjgA4wDDphzgiNLHBkEiCSShFYJZa3S5t0JXb8/uqu7uro6TFjtjqbe59GjnZ7u6uru6be+er9QxBiDhoaGhkZtwRjqDmhoaGho7Hlo8tfQ0NCoQWjy19DQ0KhBaPLX0NDQqEFo8tfQ0NCoQWjy19DQ0KhBaPLXqHkQESOi/Ya6H6WCiE4iog1D3Q+N6oImf41hCyLqFv5ZRNQnfL445JiKEiERPU5E/c45txPR34loUqXa19AYKmjy1xi2YIyN4P8AvAvgA8K2P+7Brlzp9GE/ACMA/GgPntsHIkoN1bk19i5o8teoOhBRHRH9hIg2Of9+4mxrAvAggMnCDGEyES0goueIaDcRbSainxNRptjzMsZ2A/gngMOEvuxPRIuIaCcRrSCiC5zts5zzGc7n24hom3DcnUT0eefvy4jobSLqIqLVRPRvwn4nEdEGIvoKEW0B8BsiaiCi3xLRLiJ6C8BR0v35ChFtdNpbQUSnFHutGns/NPlrVCOuA3AMbBI+FMACAF9jjPUAOBPAJmGGsAlAAcAXAIwFcCyAUwB8ttiTEtEYAB8C0OZ8bgKwCMCfAIwHcCGAXxLRPMbYGgCdAA53Dn8vgG4iOsD5fCKAJ5y/twF4P4AWAJcBuImIjhBOPRHAaAAzAFwO4JsA9nX+nQHgEqGPcwFcCeAoxliz8/3aYq9VY++HJn+NasTFAL7DGNvGGGsH8G0AHw/bmTH2EmNsCWMszxhbC+C/YZNvUtxMRB0AtsMeQP7d2f5+AGsZY79x2n4FwN8AfMT5/gkAJxLRROfzX53Ps2AT/WtO/+5njK1iNp4A8AiAE4TzWwC+yRgbYIz1AbgAwA2MsZ2MsfUAbhb2LQCoAzCPiNKMsbWMsVVFXKtGjUCTv0Y1YjKAdcLndc42JYhoDhHdR0RbiKgTwPdgk3hS/AdjbCSAQwC0ApjqbJ8B4GhH3tlNRLthD0yc7J8AcBJsq/9JAI/DHnROBPAUY8xy+ncmES1xpKPdAM6S+tfOGOuXrn+9dP0AAMZYG4DPA/gWgG1EdBcRhd4bjdqFJn+NasQm2MTLMd3ZBgCqMrW3AFgOYDZjrAXAVwFQsSdljL0O4LsAfkFEBJuAn2CMjRL+jWCMfcY55AnYFvxJzt9PAzgeguRDRHWwZws/AjCBMTYKwANS/+Rr2gxgmvB5utTPPzHG3gP7HjEAPyj2WjX2fmjy16hG/BnA14hoHBGNBfANAH9wvtsKYAwRjRT2b4atv3cT0f4APoPS8TsAEwCcA+A+AHOI6ONElHb+HcV1fcbYSgB9AD4Ge5DodPp3Pjy9PwNbpmkHkCeiMwGcHtOHuwFcS0StRDQVngwFIppLRO9zBpV+5/xWGdersZdCk79GNeK7AJYCWAbgdQAvO9vAGFsOe3BY7UgxkwF8CcBHAXQBuA3AX0o9MWMsC+CnAL7OGOuCTdQXwp55bIFtZdcJhzwBYIejzfPP5PQZThv/AZvQdzn9vDemG9+GLfWsge0fuFP4rg7AjbD9E1tgO6KvLeFSNfZykF7MRUNDQ6P2oC1/DQ0NjRqEJn8NDQ2NGoQmfw0NDY0ahCZ/DQ0NjRpE2UWiiG
gagN/DDn9jAG5ljP2UiEbDjqqYCTu9/ALG2C4nPvqnsBNZegFcyhh7OeocY8eOZTNnziy3qxoaGho1hZdeemk7Y2yc6rtKVAjMA/giY+xlImoG8BIRLQJwKYBHGWM3EtE1AK4B8BXYtVdmO/+Ohp2Ac3TUCWbOnImlS5dWoKsaGhoatQMiWhf2XdmyD2NsM7fcnZjltwFMAXAu7IQYOP+f5/x9LoDfO3VMlgAYpeuja2hoaOxZVFTzJ6KZsKsYPg87VX2z89UW2LIQYA8MYl2SDc42ua3LiWgpES1tb2+vZDc1NDQ0ah4VI38iGgG7RsnnnTR2F8zOJCsqm4wxditjbD5jbP64cUrJSkNDQ0OjRFSE/IkoDZv4/8gY+7uzeSuXc5z/+UIWG+EvSjXV2aahoaGhsYdQNvk70Tu/BvA2Y+y/hK/uhbfIxCUA7hG2f4JsHAOgQ5CHNDQ0NDT2ACoR7XM87IU0XieiV51tX4VdXOpuIvoU7CJUFzjfPQA7zLMNdqjnZRXog4aGhoZGESib/BljTyO8Nnpg7VBH/7+i3PNqaGhoaJQOneGroVFFyOYt3L10PSxLV+PVKA+VkH00NDT2EH7+WBtufnQlGtImPnCoXp1Ro3Roy19Do4qwrdNeyrerPz/EPdGodmjy19CoIljO4ktG0SsQa2j4oclfQ6OKwKV+gzT7a5QHTf4aGlUEbvlr7tcoF5r8NTSqCDrKR6NS0OSvoVFF4NyfLVhD2xGNqocmfw2NKgKXfbJ5Tf4a5UGTv4ZGFcHhfgxo8tcoE5r8NTSqCAVH9xnIafLXKA+a/DU0qggD+YLvfw2NUqHJX0OjitCX4+SvLX+N8qDJX0OjitDnyD3a8tcoF5r8NTSqCP1Zx/LXmr9GmdDkr6FRRejN2QXdtOyjUS40+WtoVBH6slr20agMNPlraFQR+rXDV6NC0OSvoVElYIx50T5a89coE5r8NTSqBNmC5SV5adlHo0xUhPyJ6A4i2kZEbwjbRhPRIiJa6fzf6mwnIrqZiNqIaBkRHVGJPmho7O3oF6x9LftolItKWf6/BbBQ2nYNgEcZY7MBPOp8BoAzAcx2/l0O4JYK9UFDY68GY145Z03+GuWiIuTPGHsSwE5p87kAfuf8/TsA5wnbf89sLAEwiogmVaIfGhqloDebR192+MsoAvdr2UejbAym5j+BMbbZ+XsLgAnO31MArBf22+Bs84GILieipUS0tL29fRC7qVHrmPeNh3H09xYPdTdiYYmWv3b4apSJPeLwZfZ8tagliBhjtzLG5jPG5o8bN26QeqahYaOzPz/UXYjEii1duP9125YyDdKyj0bZSA1i21uJaBJjbLMj62xztm8EME3Yb6qzTUNDIwRn/ORJ9+/6lKFlH42yMZiW/70ALnH+vgTAPcL2TzhRP8cA6BDkIQ0NjRikU4Yb8qmhUSoqYvkT0Z8BnARgLBFtAPBNADcCuJuIPgVgHYALnN0fAHAWgDYAvQAuq0QfNDRqBSmDoLlfo1xUhPwZYxeFfHWKYl8G4IpKnFdDo1xUowVtGlSV/dYYXtAZvho1jd7s8Hb0qpAy7NdWjPvX0CgWmvw1ahrVEN8vwzQIQHXOWjSGDzT5a9Q0eqqQ/FOmQ/7a8tcoA5r8NWoaPQPVKPvY5F8q9/dm81izvaeCPdKoRmjy16hp9Fah5W86mn+pss+nfrsUJ//ocffzrp4stnX1V6JrGlUETf4aNY2eqnT42pa/VaLp/9zqHQA8h/Hh1y/CghserUznNBKhoy+HzR19Q9oHTf4aNY3egWq0/B3yL7PCw3B0GC9duxN/efFdAEBnf26IezN4OP2mJ3Ds9/9vSPugyV+jplGNln/aLM/y5xiODuMP/+o5fOVvr2Pp2p044juLfNbxDfe/he/e95Zv/427+/Cjh1cMi7DX1e3dOPOnT2F3bzZ2362dA3ugR9HQ5K9Rs8gVLNy/rPoqi7ihnmUSXrkzh8HEpo5+5C2G7V0ekd
721Brc/vQa336Pvr0VP3+sDZs7ht5n8YvHVuHtzZ1Y9NbWoe5KImjy16hZPPjGFjzxTvWVC+dJXuVa/vlhzP75gt23XEwfs0510+FQ5ZTsMbmoKKyhlN40+WvULCzhxXOM6apApTT/Ycz9yHHyjyH1rLNff27ofTf8N1TMoDyUsqMmf42ahVgW2WL+wSAJVmzpwpzrHsSGXb2V7lokyo324ShGNuoeyGPtHswNyBXsvuVjnkkub38/PMjfyb8o4hg54ODB1zfj8t8vrWCvwqHJX6NmwRdEv+z4mQCK19D//MK7yBYsPPzmntV4K1XeoRjZ52O3P4+ThNyAwYYr+xTiLH+bPIeT7GMxhkVvbcX27ninrmz5f+aPL+ORPeQz0OSvUbPoc6zF5jq7uO1wDH1UIW1WRvMvRvZ5df1uAHuumBy3/Pn/cfsNB8sfsNm/d6CA//f7pbjkjhdijwjLMN8T91mTv8Zeh1fe3YVdPfHhdryoW5ND/nESA8eO7gG8uamj9A6WCVfzL4IffvzIChzzPX8iVynRQv17aO1g7ujNx1n+ea75h++3paMfC3/yJJZv6Yxs66E3tqCrjNwCrvnzvq9uj5fJekLyTPaEIaLJX2Ovwwd/+SwuvHVJ7H79+QIyKcO1pAsxVibHWTc/hbNvftqd5qvwztYuXP3X1wblJU6VIPv87P/asKXTHw6Z9HpF9O0hC5uTes7iMwA1uQ+40T7h/Xp1/W4s39KF//jzK6H7rGrvxqf/8BKu/uuyyH5ZFsPtT61WWuxc88879zVqcOW5GrykeH+ugP2//qD7fVJDpBxo8tfYq8Cnyyu2dsXu258toCFtulUy48IKOeQEHdUU/bN/fBl3L92A1e3didosBqZb2G3PJ3ntqfUPOKnn8tHRPHxQGIiw/OtSNs29s7UblsWwur0bP3houc/Bz2eB63ZEO++Xb+nCd+9/G0+tDIYIc2OA9ykqgKAuZQLwqsq2dw34Zi+a/DU0ikQxL01/zrLJv8RCaYRw099zWPrb/PvLG7AzgSQVhUqVdC7mejmx7SltnZ+HO6XDZB1X9omw/MVZQ95i+MJfXsUtj6/CKmFgpoRhmu5go3Awc8ufh59GPR8+IIVp/nFyVyWgyV9jr0IxhNaXK6A+bbgySqnW1t9f3oi/vrQBAPDyu7vwbNt2V4/v6PM05PU7e3HV3a/hyj+9nLhty2IBC9JN8tqDtX3Szjn3VBVUTvZxDt0klr/4XPOWBXJIWpTBogZyf1vcFxG8d3wA4QNS1Dgik7/8LLTlr6FRJIp5aWzyN73QySI1cOZEdL+1uRNf+p/XAAAf+uWz+Ojtz7sW5HOrd2CrQzLcItySoBTBlo5+3LToHezz1Qfw8Tue931nFhnnL1qRolQkE87Ma+7Hs23blW3w2UYxK5+9sbEDd7+4PvH+IgYcss/FJHFlY2QhsQ3A/n20NqYB2BKPZTFs6+qPzc5ljKFgMSH/IDjY8AEkK8wKZl5zv1Keq0vbsg8fTOWZhGpwqTSGjPyJaCERrSCiNiK6Zqj6obF3oRgC73fInxNbseUOojRd/r7f/OhKnH7TkwA8WSCJXPPwm1vw00dXAgCeadvh+67YJC9RZhIJX3X8n154V9kGd4r3FiH7vP9nT+Pqv3kOVMYYTvuvJ/DPVzbGHstlnLxr+Qefzb5ffQCPLt/m218FUXorFBgaMjbxrtvRg8Vvb8V7fvAYOp0ZGgtJ0bpp8Urs+9UHXJ+HKgSVR/vITvEoJzm3/GWH9Z4ovTEk5E9EJoBfADgTwDwAFxHRvKHoi8behaiX5tm27T6dtz9nO3zNEmUfeX/RKhaJlUs/JiUn7e6IFcZMs7hon21dnoNa7HNRso9zzv4yZB+LASu3dePzf3k1dl8u47jlGxTkLvY/UvaRLP+ufvvertvRi21dA8jmLXdb2C350/PrAADbu7OBNjkM53ckz0K6+4PPkv9OeZJXtoYs/wUA2hhjqxljWQB3ATh3iP
qisRchitA+evvzOOXHTwAAvvw/r+HFtbvQkPEcvsW+cPK5xPLDqtBE16mYwKiLIv9iLX8x07Tg07+D/oTwc5av+Rcz2AQt/+jzRlr+kubf6ZDxmu09AQduWAQVj87hg4TKUODHytJYp4L8+Qw1VPbZWy1/AFMAiGLgBmebCyK6nIiWEtHS9vbqq7yoMTQIs97FqIrO/hweW2HLBaLDt9hoH3nqv2m3p+WLjl4O3n6SEE2VtchhulU9E3XTJSzATzIWY7GzncdWbMNr63d7mn8J0T78uovJSOYyDyfBOF/DQM7C+p29ysFMLA6XLzBX4lnV3u0+Jz64hPWwPm3f8w6nVr9K9uH3Ur5HqoGc78vvTVD22Xst/1gwxm5ljM1njM0fN27cUHdHYw+iP1fALx9vKykkMsx6Xy8UX3ts+TZ3+l6fNl0ZpVhrqyDtv8ln+Xv9SJuE51btcLX+JO91lOWfLnKwEmchIjEVLBZoQ27xst+8iHN/8Yyr+Rfj8OXgkoZ4riffacd9yzaFHsPJ2JN9/PdaHkBXb+/BCT98DDctfifQlvhcC5ZN/lNGNcBiwEvrdgEQBkWh2fauASxxlrysdxy0u53BQiX78N+efI9UA3lBIv9akn02ApgmfJ7qbNPQwOsbO/DDh+xyBMUmMoUR+PqdHjGv2OIlgNWlzIDl39GXw2f/+FJsVI5snYXtnyswXHTbEry4ZieAymn+SS1pkVj6hCStgsUSJ7Zxv0gplr9L/kJ/P3HHC7jyT69g4271OrZunH+I7CMPWjyiSrWQijgQ27JPDu+dMxYA8OJa+5lwy1u8p5/+w0u48NYlbmAAAOzqtck/J53/w7c8izuX2H4B+R6pSkYELX9Z9tl7yf9FALOJaBYRZQBcCODeIeqLxjADf+GzBavoao1h1vD6nZ7lv0kgnN5s3iU2ThJ3PrcWD7y+Bb97bm1R54oqMQAAr22w6wElIv8I2ScVUs//xbU78S///VzAivRZ/lm/FZw0OooPwqVo/gNO5U1RkpnQUgcAuCskusiVfdw4/mhZhP9mVPXxxevv6s8jV2CYProJY0dk3PO4mr9w3LYue0BZubUbDdzy71U7fJc6MwhAQf6KgVy2/OXf+b2vbkoUElwOhoT8GWN5AFcCeBjA2wDuZoy9ORR90Rh+sEJi0Tv6cgFikxFmMW3Y5RG+qM3v7Ml6tX2cY/m+jc4LHwZZ942z1t7daRf6Klf2CVvJ6+q/LsPza3bi3Z3+EgUisfhkH5bc8ufXVkqGr0r2Gdlgx9qL/ghxlscH0mxIqKd8r/li76pCaaKEwqXEkQ1pZEyP/lzNX2h2v3EjAADLt3R6mj+XfSIeojjAAtHRPmHkf8cza3DRbfH1qcrBkGn+jLEHGGNzGGP7MsZuGKp+aAw/iCTBpYJ8wcKh334EV90dHSYYZvnv6BlwrTdRm9/dmxNCPe0XcLkjC+2I8TnImr9sRX/1rP1x0JQW9/M7W+0w0ySWf1jaPxBe3oGTWdDy9/bzkX8hqPmHeTy507SU2j4q2YeTtCjTiYQqW/7ioGMpZix8RqIaNMUBjpN/S0MKmZRHf57l77U7pbUBgC0T1rmyD3f4hg+afdI96orS/J17Is9sAIRKYpXCsHX4atQuRJLgUsEzq2zHW9zi2GEWWd5ibqy6OJ2eO7HZp/nnChbe2mSX/t3aWZzmL5PxvuNG4Ph9x7qf2514+yThlSqpgENV2G3mNfe7xezksEe/7CNo/owldixyjbtPEU//0BtbXL1bhaxb6MzbxgcR8fy+uH0e6ulsE3MVomYsqpmheA5OxI0Z053xAV6egPgIeXdWbut2yzHs7uUO3/D7JifCdQ/Ea/5ZVd4AAV/56zJc8cfk5UCKQWpQWtXQKAOqEgQPvbEFALBg1ujIY8MKYlkWQ8p52fMWw9gRdfjtZUdhv/Ej3MSvvMWwur3HfRHjyF+2muXPRF6UiIgkftokmn/Bsu+VrMPLx2ZDZJ9/u/MlXH
XanPjOwLuvslUL2I5RAPj4MTOUx+byDN0DeV+CHa9mKQ6g4iCVK/hJ8e3NXi1+VZRSFFTRTinD8JN/Pij78GvuzebRavllKnHGIgclyM9Xno1YFnP3cWUfxaBqEGH19m5X5qs0NPlrDDuI/M2tab44i2q6nc1bmPO1B/H9Dx2MWWOb3O2MMbeIV8FiLmkCtlV10JSRAOBL8uILfhw0pSVQuhmwF4rhCDgdA+RPbikBEXGyT75gRUbVmILmf+3fX8ddUv0cmWzCHL4A8F+LgqGRKnAyLmW5xGyhgEvueMENqwS8AcmXfauwpu9fthl1qVddKQ6wrztueUdV3wGv/ymDfLJPv2v5i5FBXhim/GxVbYZBTvLyF5oLt/xNIuzuzWFfx/dQaWjZR2PYwVd/xnknuKWlijbhGaw3LXrHd6z4glrMT/7iQixibZ+3NnciYxo4bt+x2NbV75No1m7vwQd/+ayyn0Bw1mEQuX4G3/XFkH9PTESNW4W0wALEDwRDC0VyitPsRc27oLDKS1mcZiBv+YhfhBgyGSbl/P3ljegeyGOfcfbA/ptn1kbmG0Q9F66tp0zD5/B1LX9FO9lCUB7ztxlN/vJMzP/7jrD8DUJHXw6jnEJ0lYYmf41hB5/swx2+zkuicuhxJ96I+pTPqhJDLwsWc+PjAX8JX1Hzb9vajX3GNWFCSz1yBeZz1skOONkalCNSCFCSfxx/RkX6AN5g9dqG3crvZQejaCUXY7mLx3HLtJT486gILdFxG+d/OGCi7Tz/z4dX4Gf/1xa6X6eUXa2UfUxSOnzFWRnvT66gsPxF53RMiK/svC9Is4vvP/A27nhmTeA4g+ykMh4ZVWlo8tcYdlA5fPmL2KsI5eOWf3NdyheBI5JO3mJuTXrAq8AIwFfYras/j9bGjEsM4nRctqjlaB9ZqjGIUK+QfeIS16L0fsAbrDp61evNRsk+xZC/SHiyRl0MosjfF+0TQ/7jndwAwHOeq7BbJn9FAbiUQW4AAODdl+3dWbyxscPXt2zeCjzrtzd14tlV231thkGWdArSjPS/n1ytPi5vIZu3MFJb/hq1Al+oJyd/XgVRYRXzUg1NdSkfgch1bFKi5U+i5e/F+fflCmjImMjwpR2FFzeg3RZky18mf7XlH+fwjbP8ueYfllQmW/4i+cYloqkcniJKsvwj9Hk5+zYK45o98o9uUy6V4H3mVrrs8OXPrmAxvP9nT7t/8/bk6169vQcfve15X5thCJRuSDjgcflPW/4aNQNLJfsIGZyy5bzDsfxH1KV8A4f40hUs5pIm4Nf8XcvfcbQ2pL0wQDk7VIRMCAELMIT84xy+sbKP098wK16eOYgEW0ySnIpgZQs4CZJa/lGEDvhJMMrhq6qTw58DJ/l0iOwjwo04yluRJB2X+Ca3HbemgoxRDZnYfUqBJn+NYQeRX1zZx60MGdTWuexjGuSf4ouWvwXfNN/n8BVkn76sXcdFTf6y7ON/cVWyT106+IrFav4xso8ZR/7S4CGSahz552Kib0opOBZ2zsaM6ba3paMfC3/yVGQ7osM+ekDx9zFbsNwMXf7bMQ3yO3ylZ8dX7uLHR8ldcVJa0PL3a/6iBKmCtvw1agayQ8z+33uB5PotXPbpzxVCNf8CYy5pApLDV1gcpT9XQEPGkwSy+WCCEIdsfarIP2m9fBFR2b1if8MknE5pkBLvQ6zlH+OALUXzD7PSRzak3ee7YVevch8O0yB3JTTAu44ZYxpx9cK5kedTW/5ynL//mIG85baTzVuRZTDKsvwthqZMdMS9jvbRqBlYKs1fICKZHLnl35+zJM3fH+0jx/lzuHH+XPNPm8ikgpq/bPnLRCqHH8ZZdGGIyu4V+xtG5CqHL7/2OGklq4jwEVFqqKcKIxvSrh6vSoYTYRL5Bm/etx+efwimtTb69s0pNHbePi8NHYj2kQh8e/eA++xzcZZ/hMP36Fmjsb17AGfc9CTW7ehx+uOXNVVBASK05a9RMxDfM0
sI9eTvvly8i6fc9+cLoZq/Hecvav4ekYiaPy/fm0Tzl8lRtsSJgCNntOKTx8/CpcfNDL9gCWXLPs7x97y6Ed+6901k85Zr+cZFpviXPBxch29LfVqQ86LbJYKf/AUSlwcOudxytsDcfTjJmwZFWv7v+cFjePnd3U7fou+byuGbMghrbzwbh04bBQBYsbULv356DSyL4VePrwIAZFIG8gUWOxvT0T4aNQNR9hFL33ILSJZ9vIqT/qgMbgHbcpAc7eMdz63i3mwBFoOP/LNR5B9j+RMRUqaBb3xgHqaMaoi7bKHf6hBOub9hhMQT4T5316v47bNrkStYrnUZZ/nnImSfjGlUNNSzpSHli6WPgiz78P1ThuHW3XG/Czh8LTfT2pV9DAPplNdenHTTm1MPyIwx5XPgA5XoVyAAT7Vtx1+W2ol5dSn7fsbJfHGyUKnQ5K8x7CDKPtwizBUsj/wVtVIA26oTyWlnTxZt27qw/9cfwpubOv0ZvsLxhkEg8gYLf7SP1568NKNMavJqU4ZidpEE3YpcBhG8rbAQQ5nIbMs3Wiri4KT60rpdAR9GfdqIDMdkivwMAPjJ4pWBfU2DUJc2XS1d9K2o0FSXUlr+phG0/OU+5oXrHxBmDHURlr8MuSyGdy6mfA789yNKS0Tk80nVpQz05QrIWwz/8b79cMH8qYF2UgYV9dspBpr8NYYcdy5Zh/uXbXY/++v58/8Fy18iR/6y9+cKPgtyR0/WV1bADJF9ANsS5HKJHefvkL9ACjL5B6p6yrV9hL+LI/+8z1/wkSP9pOA6fGMsf45svuDKPknIf2tnP86/5Vl85g/+apL1aTMy2sdXDiImJDRj2msnR1W15PjsSfvizk8tUGr+adNwid37TsrGLVgBh68c5x8nZ6kK2gH2/dykKL3Mr0uelWRMb6CqS5nub2p0UwZnHzI50E4mNXgUrQu7aQw5vv7PNwAAZx9yNgB1kleuwNBcb5N/bzaPbN7CRbctwaSR9a7c0p/3HHMG2U478eUTQz1lLjYNch2tDWnTlQTEwSQu/l6GaPkbxZB/fw7N9WmXGI7fbyzOO3wKLr7dTipKxSR59eUKPss7W7Awut5OkBqIkVd4ljPgOdI56tOmcmF68dj+fB692TxG1EVTSyZlIGUYnuyjGJQuWjAd+45rwr+esA8AYMPO4BrJpkGoS5nu3wWLBZLTcpbl1uMPK+8Qh95cwW1fxG+eWYNfPLYqsD+fDcjnEKVH8bumuhROnDMOa+xo6qQAACAASURBVG88GzOvud/dLg8elYQmf41hB9Hy538XLE+3zRUYdvVmA8XC+p0pNABMaKnHju6sL1Ii5dNf/WScMsi1/MM0/2LJX+VXSIKegQJGNnjkz2Upua0oqUL8rmeggMkjnQEjRtvOOSUFVGhIm24dJQ45YenMnz6J9Tv78No3T488T13KQNokX0SNjGP2GY1zD5viflbNntKmV6ahMWM6yzQqZB9ngBCreoqWfxwYA5rqzECW9+K3t4XuD8iyj9+PIhJ72GA5mJa/ln00hh18JZ2FUM9Gh/xXt3fjmbbtgeP6cwX35RrfUo8dPQO+yJmwqp6AvSi6q/mLso/gkIyTTGT4sojlE0agayDvG7RMIt9gxUkwm7dDOJvrg8QhVu/s6s+5A2esw9diofuoNH9fUpjFsN6xzsPWVeDIpAyfJa06p5wdrZo9iZo/d4zKy2vmCgyZlAGDbFI2DQIRIZ2AWBuFMEw+8xQRZ5lnJINDvF+y5a88XpO/Ri1BWd7B8sj/9qfX4Kq7X1Mc5023x42ow47urM9aF2UYWfNPGRTi8LVf1riIDBVKd/jm0NLgkYFBUl6CW4LajmB6/qun4Noz9wfgSVuiszYnJDnFDWDtXQOhi6rXpU305yzc/tRqt503NwmLrPjq9ERr6LblbwiWf3B/eS0E1QCaNr1on8Y6PjMMDlAZk1y5jBsBdQksf5H8m+qC8fhbYhb8kclbHOTqkpB/Eb
OTYqHJX2PYgBOKL9pHyPBtSMerlD0DeaQMwrjmDLZ3Z33hmabhWeMyjaQMw923Pm24JMqJRLUOaxyiyP/Hj6zAr58OlvG1r6GAFsHKNAzyWb1iW2nDQGMmhRGO9c+tYDnsNCn5AwisETDeKajG2/zu/W/jlsdX4ZE3t+D8W7z1DUTCV627ICKTMpEyyD2G3+f7/v09gT5zqBa0Ei1/Lp2oau+nTMO9b5z8xVDPMIgDkEqaWbfDn5ksk71I8NlCIdTyD5d9ohPAykFZ5E9EHyGiN4nIIqL50nfXElEbEa0gojOE7QudbW1EdE0559fw46V1u/D759YOdTdKxu4+W09WxfnnLYaGTPjPlfNs90AepkEY01SHnT0DvlIHpuGJJzKRmAa58fW2w9cfGinnFiSBaKjKksXP/q8N19/3lvK4XN7yhS/aso8HsTQ17yffo15ybHKIcf5FKFAAgO998GAcNm0UjhaW0Hx942489OYW337ijG23s9D5lSfvp/R31KUMmCYF4vxbm7wiZnIIp9LyF+L8uewjS0g5Z4bEZ0zc95NE8xcHoBEK2QcAPnT4FPzg/IOdPvj7XCeQd3/O8pG/+J0cscQxnGWfNwB8CMCT4kYimgfgQgAHAlgI4JdEZBKRCeAXAM4EMA/ARc6+GhXA/yxdj5sSLss3HGBZzLeuK8/UlTN8C5a95mnGNEPlE/7ic8t/2ugGWAx49V1vwRODvEShgMPXJLfoV33a0/x7swV86943saa9J/JaVN3yl5BIzrg5y/JN97lG7bYrtMXb5V9zspItbzHDtxj/AwAcMLkF/7zieB8xr27vcRe65xAtf/4s50xsxkxhaU2OTMpA2vB8CHyQrRfILiD7qDR/0743mZThyjKy5W/LPoZn8TuDgJHgPjQICVYjFLIPAMyb3ILZE5oBAI1SQpZcOTQn5DOIzziM5IdttA9j7G0gqJ8COBfAXYyxAQBriKgNwALnuzbG2GrnuLucfdUmkEZR6M0WSsrAHCo8sbIdn/zti+5nvk6vv7aPF8efMu1KjH1WUFJoqjPRPZB3Lf/j9h0LwF8nxxSiZgIOX4FYGjKe5n/va5vQtq0bD76xGVHIpIzgSl4iYccQzc8eXYkT5ozDYdNGoWAxXzVQwyDfQMJ9ABbzrFf+NbcgQ2Ufp86PSNSZlBEpB3GyFO/R6u09gXsoav58Fpc2SElgdSkDKZNgMft5c2tdJPwkDl9O6PUpA3VpEwb5NX9uOKQMw83z4NeRZAhsTEfLPgAwdkSd2w/ZL+BfJ7jg1/yFZxw2CxlM8h+slqcAEIXDDc62sO0BENHlRLSUiJa2t7cPUjf3LvDyBNWC3b1Z3+Ih72ztwhPvtEtVPb1ibfLqSyK45d89UEDaNDBtdKNvMXeAyz7cUg46fDka0qZTTgB419F0UyrBWYDKMScSfpzl/+NF7+C8XzwDxhhyBeZ76W2y9w8k/DO/H3Mn2pbnSXPHAwjKPpxU7XUN/H1RrTkgIi05SjnkcjxiJBBfZcw01PH0Ym2dnOVZxKIUIvdLNWNxyT9tos6p1CkmmPFcCNGPw59lkgmQ3+GrJv+mupT7PEzpdyL+LuQkxCSW/5A6fIloMRG9ofh37qD1CgBj7FbG2HzG2Pxx48YN5qn2GvTl8lVl+cvT86/f8yYuueOFQEE2bqWmTCP0JeHFr3Z0D7jkduXJ+/n2MYkiHb4cXGtOm4ZrqcUnLQUJNGmGrz9WnrcnyD7kj/Mn8qxgTqCHT2/FkmtPwceOngEgKPv4fAhSX8L0Zo6UwvKPuw6+lKIYjSNCrNJZsBhyBQumVMogkezjbLv8vfvgnMMm2+QvSCt8BlSf9iRDz4AozuHbHEr+pkv+8kJD/sqhVqjDN4zkh1TzZ4ydyhg7SPHvnojDNgKYJnye6mwL265RAfRmCz6rebgjbKDaJqzPes3fXse9r20CYFt5YS/JvEn24t4bdvW51uD5R07Fmu+f5X4WZR+ZR1KuDgyBJP
wJOgDQooipB+Au+ygiSYbvV//xuuscBYCtTuigaAEbUlEzg8i1gsXEtYkj61HvOMVly79VqAwpW/BxBMPvQ9zspUcYcLjmb2fSBgdGwyC3vVzBJn95VicPGjL5pwRfyL+esA9OmjseadN23HOLn9dbakibvt8BkMzyF2cfoZZ/JuUbyESI93ZbVz+2dHi/bfH84u/6yBmtyuMrjcFq+V4AFxJRHRHNAjAbwAsAXgQwm4hmEVEGtlP43kHqQ82hL1soafGQOGTzVuxiG6UgLBZ8s1ArZSBvueUfUmZ4Ys7YEXWY2mpXzjSltXr5i2lEyD4qwhfJiId6toTUVq9TSCdJMnz/9Py7uOVxrzzAa+ttB7XP8pcyfAFv8JIHHe5wlGvRiM5amUTjpAV+T2RJQ4a43gEf0FKGAZVSJ2bY5gsWsgUroHuHPaOwz4A9GN69dAPO/Km9Khi3/OvSRuAZJ9L8Bctf7N8CIfKpqc50B4mJI+t9x4v3du2OXtzxjBfeK858RePgb585Dl88bY59jcWGZhWBshy+RPRBAD8DMA7A/UT0KmPsDMbYm0R0N2xHbh7AFYyxgnPMlQAeBmACuIMx9mZZV6DhohjL/52tXRjTlMGYEXWx+37lb8vwj1c24q3vnBGIZigHYZb/5g514kxUSn7KIBwydaRj+RuB7wZgv0j8HQvKPvaWjI/8vb87+3PIpMJlJxWB+kI9I15iMVGIz3rEcxtShq9BXtx/Sjpvgxvn73fgtjZm3IxamTTFWcYPP3wIrv7rMt/3cnx8GMRcCE/2UVelNIyg7BM3CMn3UPVb4G2sdqKzeCE3MXGPz/JUA7YMMdonZRJOnDMOHzh0Mo6c0YqTf/Q4AHvAnTyqAT+98DC8d/Y49OYKSPNEshBJ7ZxDJ2N/x0+jAn+uimCaiqHcaJ9/APhHyHc3ALhBsf0BAA+Uc14NNXqzBTBm645xP5rTb3oS45vr8MJ1p8a2u/jtrQBga6llriVdsBjW7+zFzLFNoZb/po5glUQgWIlRhGEQDpw8Eg+8viVQ0thn+ZM6zM8lOMFMFc/V1Z/H6KZMqCWmGhR8Dt8QRzXgL6C2w4l4SpvkRvSYBvnyEkQHsEzI3MHam8u7xwPAqIa0R/7SNWRSBr5w6hwsmDUa7VIxN7ntKNz/uhcR5ck+hvK3aJLnvM9Z9oImcXH3ySx//zb+WxA1fz6DOWG/sfjyGXPx9uZO3LdMHc0lWv6mQfjdJ+2gRXEmzIMNeB2iVuH4OjM4wDSkTdx80eH4c0gmNSBGWIXuUjZ0hm+FsKq9G5+765XYRSkGE3yqH+f05dKQqK0n2T/RPDkGP3pkBU760eNYv7PXV9tcBGNqPTaqEmPKIOzjRPfIMwduJfuSpaT2OfGIBKSKigkjwDjyj7L8+RrEgO2wBnhooueHkEtTuIuFKM7bkDbRly34+jqqMe1p3maQ/D936mwcu++YyNDCqAEMgK8sNy9KlzJIOWCahlduIV+wkCuw2IxbuR0l+Rsy+XshpG6cv2AMXHHyfhjX7J/9is59n+wjjMDiLKUxJP4fUD8fl9gjfhPesx88y1+Tf4Vw1d2v4Z5XN+H1jR1Dcn7GGHodKydO+im2VAFvrxL+hGdX7QAAtHcPRNZ/SSv05ZRhKB2rgP2yzBoXTCYCPJ4XyzvIvKHS/OUqnoYRPg2Pk32irGaf5d/NtXLPyWuQOs6f7ydDRf7N9WnPSRwRjhhF/sXUJ+LVP1MmKcsyGORl3D7dth3/eGVjrOUvt6MqHudbU6BguY7v+pTpSinyICY/u298YB7OPmSSfVxIlFQ6RB6UoSJ/vi3qfqaEAWqwoMm/CGTzFm59cpUyIYaXyh3MpIwoDOQtN+5aNKjX7ejBYd95BGu3exmqu3qzKAa8vUpEEvGfMmP+pCAZqnh+ed1V+buZY0LIn7wXif+tKukM+F9WuXa9HZ6o7q/qJU9a0plLJICfND2ZgiBOVcgX5x88b33aCKwqZhrkWv
xhA599bLgVK17DU1efjMe/dJJyv1GNdjnq5voUJrU0KAfMlGD5X/cP26Efp/nLZCmXVwb8pax39WQ9zT8j1vaRBj/p2Yl5FOL77JMEE77nUQQf9R1/MwaR+3U9/2Lwu2fX4nsPLIdB5C4wwcEHBPHH0t41gLRJGNVYplCeAGJct0jSbdu6sbs3h3cdnR0AdhZL/hW0/EUeiLT8UwYgxarbtdvDyT+KuAApzj+EAEWCk+UzI0TCsPtWuuwjYnuPI/uYhnsunnCmaltN/qa7brGIMMtfJCHx9/v0V07GFkFCE6N9po1uBGDfR9km+N1lC/Dc6h04bd4EjGxMK6/dMChogccQapLIF5H8t3dnXcu/LuXJPnGWv2nAjVAynNIRdvlsdSRYHH7/yQX4tztfcvsiLkQj/i+Cv2ta9hkm4KUCVLKJamGNo25YjMOvXzTo/QL89dvFl573VfRF7C6S/PlgUskcAubU7AmDKqM2iRU1e/wIHDp1pO87EvYRpRTV+aKm8GKUjQzVjE/cM04v53Atf8M7l2iJcgQTljxw8pfLJIuObxHiRzHyZ2prI+bP9EIaVbMX1f2aOLIenz5xX+w7boR9XpX/Rojz54jzVSWRnURn/46eAXdG3pAxXeNAPq9sxYv3mwGY7gx0cjXVpHjvnHG46/Jj3M/8Ol3yVxA8f3Sa/IcZVD9RTv6yDLmncq7EWi6Wj/xtSUEk/109nszw0Bv+yowinmnbjmfatrvXUInsYf5TzhasSMtfpe2nTUN57wHvBVp01Ym458r3+L7j749Y1VN+pzg5R03nTQUJu/0tw+HL0ZQxvSgZwdo3KNhf/lkO9QRs2UdVUjklzG5OPWC8sm9R2b7KlbSkbb/62JGY0OKPdVddu2lQoO/cx7Jg5mhcetzMROeXIRphOwTLvyFtYoyT6xDl8wCAKaM8qcpizCV/MXu3WC3+0Gmj8NMLDwPgvYveYBzcn59rsBZvBzT5F4Wox5B1MgrllY5EvLh2p1u8rNIIk324Liouai1q/p/+w0tYs11dsfLi2593140F/L6EUsFfqmzeCo32AdQkXKp+yp+cv7BbiCUd0U6U7BNX2yfJSyzmXMiaf7GWv7j4DC9LYAjt3X7JUbj+vIMCfYuKfVda/tJzmj1hRGAfleZvEAXuNZ+l3v3pY/Gtcw4MHiPsP765Ttkf0UDZ0tnvq9TK768c7cSvIWMa+Ptnj8P8maNd3w5jDNOc5MENu9QhyEnBZx7c6HFlPcX94VLrYGr+mvwjwBjDd+97C23b7LLD7jNSmPMD7kIk4W195FfP+ci0kugNsfx5PXtxgWzRwQgk/4FVUva59Dcv4ran1IuZAGESAwVqp3BEkasX4SM6fNXni5Z9wksCKB2+wqYkJZ3HjvB8QynD8ElUAR9FlOafMl0ivfS4mXjqKyf7+iD7N0RSLTbaR7aiVX4X1aWbRrA9MUNYeX7hJjx59cl4/VtnROwNrNzajb5cARlnIRd+f/ulWRFf0YuB4YjprU6fueUPXHyMXS/ppLnl1RhrFIrrAd4gpJpFWFr2GRq88u4uHPO9R7Fiaxduf3oNLrnjBQDBCBERruwTQk78+7c2dyq/Lxd9OUHzZ6LsE9T85WifpHJOJWWfOKiLeJUWhihq/mGyD29btGRv+pdD8flTZ/vOURT5C38nkQnGFmH5u7KP4p7Upw138ZlpoxvdgANDIn9VLHlktI9ilvHd8/wWurKIW8izlGUf1VKOYe3Up81A4TcR8ya1YMXWTvTnCm6WLb+/chQXf3aiDMmNhILFMGdCM9beeLZbs79UNIYsSxk1gxnMDN+aJf+X1u0KdXz+ZPFKbOnsdxcC6ZXqpMg/UTHbL4wgeXXIwXqWPtnHYrjtydU47DuPqB2+0o8/KalblQj1DLl+mTRUFk/KoFAfShLLX4z2CTp87c+ir+GDh0/F50+dI7RDoQZAbIZvAvJvbVRb/qpBx5V9FIlR9WnTXbhelIVcy5/k/7
1joy3/4HcLD5qEJ758UuTxKgI7cc64oiJmgOKs4CNntGLl1m70DOTdkhdjHMtf/v3z2ZP42xJlnzD8xymz8auPHZm4T/LAKs/ERFiuUzhx80WjZsn//Fuexdk3Px3YvnxLpxvexjlRFckj4j0/eMz9O4wgeSjoYBVq8ss+wA0PvI3dvTnXyhE1/14peSmpnFOO5b+tqx/rd/aGkqcsXyRJ3fd9F2n5e9NrbyUvddsqS9pwB4/QU8Rq/uLfH3dkBBkj6v11ZMTqkzLxuaGeSsvf9HRlI9gHfq2urmwEBwgVwr7zh4oGrXGZ5Fd97ywcu++YRHWlws4Th4OnjsRA3sKKrd0u6Y5pss8ny55RA3fUb/6q0+Zg4UETE/dJrotlKGZeHKcdOAEAcNbBkxK3Xyxqkvz5A924O+jAWfiTp7Bia5ezn03YnPz5MxK5Uv5xyDXqOTj5l6PhrdjShZnX3I9nV20PfNcX4vBtd0o4iNmQPdkCJgvVB8P6LKMc8l9ww6M44YePhX6fqG6LYYCFxPtE3Vcv2geC7KO2/KMczVHko7JiwzJ8rz/vINxy8RGB/cWyAmmhvINczx/wJAoVcYkWpq8kgZRZyslfNEiiZIaw64+Lf//CqXPwL/O9Su68nSmjGnz7zRjTGHpuIJlvinf/0KmjANhVUrnlzzX/zhDZR8Q4Z2AKq+JaCmTZJ8ry339iC9beeDYOnDwy8F2lUJPkL0og27rUFSQBL1ImivRk6Yi/UPJ0UR5ASgEn/YcV4Zmy7MN/UO3O9YnX3JctYO7EZtzwwYN8fY5D0v0YY3jynXb1lDnk+uWqmmFZoWFdiJoV8G98Dt8wzV9J4uT+H/b8lJcaQv72eYKvXrNg+YtSj+2r8B/PgxDmKKJrxHBNVfau6Vq16r6FIczyFycfqufW2pTBDz58iPLYmy86HPuMbcLiq96Le694j3IfuW1euluFx790Ev74r0dj9vgRGOkQN78fvKy1vMC76pl/+qR98cPzD8F5hykXGiwJsuwTZfnvCdRkhq9IhG9t6sT4ufXK/WQLQfWIdvaonaei86h7IO8udp3kRdvc0YcJzfUBJyEnGNULJtZvt5i9DGBvtuBWifTJPtk8GusaMdmxvKLi7UUktfz/Z+kGXP23ZfjxRw7F+UdOTXSMWFUzWwhJDIog+CQvUMowIlby8g8+/radPgrkX5cyfHKg6s5EhXqqZhg+y9/0wkoNg2CE3PtDp40KbBNJRrxn3PJ0SymzcKeiGHkUdg3uOYpIeJJxzqGTcc6hkxPvf+enFmBuhON1xpgmzHDKfBw1sxWL397mK9r3/06YhfftP8F3jMpPkTYNXHDUtMD2clCM5b8nUKPk771IUQtXd4aEnonSww6Z/J0XShxgLvvNC3hx7S4A8Zp/R18Ox37//3DJsTPw7XMPitxXhLhyU8HyyJ8PGGJ/erMFNAorG8llG97e3OkuhiEiqeXPHeDrVQvAhFnuIZEoIgyKsPwjCIiTm7+wm1pmUkamCI5X/uxl8uf3RlwcXWwpkFWqGMhGSJa/KPsUQn43k0YGreB6sR6NcF+4/GFKz13u25NfPhktDUFqCLvHe5K8TpidPNzyk8fPwkvrduHw6d4Aed3Z8wL7xRWUqxTk84gO/aFAjZK/JfztsYlcB14OCVMZX7Llz18ocR1RTvxAvOyzzVnY43fPrXPJf1tXP97cFB0iKss+tuPN678Y59+bLaAxY7qkJlv+f3tpg/IciatVO+129efR2Z9DS72nm8pTbg4e8+ytsiQ6Ku17H1UuO4nxKabtBzJ8I6bg4jF88KlPm77CYnx72jSQtwqBtpIsROK3/P3RPvLxFy2Y5vPziAiz/Bsky9/V/CXymR6ivcvJUe72ISKvOBy331i88o3TY/cbzKUSVeAzAHHp0KFATZK/aO1zQrnl8VX4wUPLffvJ5M+zd0UOki1/TqQDBfWLGfeiiINJX7aAhoyJi297Hiu3deO6sw4IPc5X3o
GxwApCfss/j8a6VMAC5AgboOJkn5yzFB8//NdPr8Gvn16DtTee7e4TFjklxzyLs6vffXIBHlvejmmtjb7tk0bWu7X7oy1/5xwRcf485lz1eMT4eObcZlm/5f6NtEngP5viNX9hnV0p2kfWqb7/IbWGLvdNtOoDso8bSx7alA9hmn+SMNbhjLhqopXEXZcf4xbF4yhHNisHNenwFS3dbMHCW5s6A8QPBMmfzxJEIt3ZHWL5h0TQxGnTIvmvaredeut22PIJn5momgha/v5HmxX6niswNAoLk6za3uOrKR8W8REl+9y3bBNmX/cg3t3RG3mNYVmccj19caCZO7EZ3/jAPBgG4cKjpgMAHvnCe7H4qhOF40NP6Uk9BvmctyJShno74A0IYqatfH95d8XFysln+fvbVMk+8qpR4iI0xfCr6PAVE6lk2ef4/cYCAM48KFk4oW8wUmyvVuwp2QcAjtlnjBvlxKPsBrNmfxRqkvz9so+Fi29fotwvYPk7x4lhkzt7/KthuZq/YOGKzzbuQYszCW4l82lpnyRL5QoWlm+x5aDenGT5S/HW/Jr5INGQ8cj/6/98A8d871F337AeRln+P3xoBQDbWR01vslOdI6PObHvpzvxzeLYKYYrnnf4FKy98WzMmdCMpjpRI4+w/OGRaJjD192usvyFMgsc8szqhNmcSNVx3/KgorI2xQElLSx8rpJ9ohBm+fP1aHlbB0yywwnFxcijICeJcQzmIuN7AkO1Bocnuw3J6WuT/LNSnZtdvTlMaAkmnHQIySDZvFeFUoyc2dGTxaSR9ThgUgsAjyDFAUakzLAXZXV7Nz712xd9uQcDef8CMTL533D/21j4k6ewYVevL9qnYAV/0HmX/O39mgTZB1CntsuISgZ7d6fn3I0a31QLcADAwVNGYu2NZ2OWs+aAGCaaZOGMKAJSyT5hZKpKQvNkH2+bPLgePNXu/1EJiTStkJnEomqmSe5MIyrEVIU42Ue1AlYShC0tOFSWa6WwpzV/DrfGzxANnmVdNRH9JxEtJ6JlRPQPIholfHctEbUR0QoiOkPYvtDZ1kZE15Rz/lIhEjOPTDn1gAmB/USi6s3m3ePEl6e9awDTWhvdpB0+mouOTZEzw96T7z2wHI8u34Z7XtnobstKlj+vUMgJ6qmV7QDsUNLebMF1GBYslebPnOuwB5DGTPh6tGG/xbDFXMRch7hs6DBwAuH6pzjLSFIGIEltn6g4/6gqiqLswyGXPvZmB7FdBaDOKK6XLH+ximjJlr/o8HW2l/qM3Htc3VwfwJ6UfUTwlecuO37WkJy/3KteBOAgxtghAN4BcC0AENE8ABcCOBDAQgC/JCKTiEwAvwBwJoB5AC5y9t2jEPX49TttS5tbnGHoHsi7Gl2uYOHav7+OJ99pR3vXAMa11LkvhrePmijDrOpRjbazb5OwclJ71wC6+nOe7ONY7bwJPhjk8gx9AvnLsk9TxnQHI+4YjlqMvFjZR5yRZPNWZPhsGLyqmn5nJJBs4Yzo2j5i1IyzTdrHy6FQtC3E2/O2ZMvf9RkkZEYvmc3bVieRtrh+QHGWf0iop2P5l/J8xLbOODB5SYNqwFBZ/q1NGay98Wycd3jlEsmKQVnRPoyxR4SPSwB82Pn7XAB3McYGAKwhojYAC5zv2hhjqwGAiO5y9n2rnH4kQa5gr3G7tbMfdy9d727nseh81aEw9AwUkHOifVZv78GyDR14+M0tyOYtvHdEnWu5Woo4fxFhJNXa6EV6NNen0NWfx5f/ugxf++cb7mISsuzDP7uWf30K6Aw6fEePyLj94XXem+pSoVEaoZZ/TN0iwJ7xlGJZyssS+mZLCczpJE5HMVNWtqT56VQWtjt4iJq/RBieJBLbDQBCSKtI/r74fHKcvk4fijC3xaghUyH7hIXbxiGTMvDcte9za+TsLaj2aKVSUclQz08C+Ivz9xTYgwHHBmcbAKyXth+taoyILgdwOQBMnz697M6dftOTAGyHo+
hU5Vr1TMHyv+q0OXht/W48unybu60vV3Ct+mUbOuxjxjTi5Xd3Y3xLXSBlPhcW0hjyQxsp1BCZNLIeXf12pM9A3lLIPnA+2+Tfm82jN5vHRKdeT4ExH3mObqrzHL7CsnbF6N7itckQyT6bD5J/ksxgVwaRio4lReKSziRtdGBFZLtyAjbIO0wm/7AoojDw6xTvtUhClsT7JwAAIABJREFURCTJPomaBQBMElbSSiuifUqVfQB1Ulm1YzDLJg9nxM53iGgxEb2h+HeusM91APIA/lipjjHGbmWMzWeMzR83rrxFFABgzfYerNneE4jLZ8wm3tFNXjr7afMm+D4Dts4vr9LFp9Hjm+tdgigoNH8RSWrDzJ3Y4vvOk338lj8n/55sAX25glsbxrKYS2b1aQONadNNOusTNP+w+OLQOP8kln/ech3Vqu/DIC8uXuzCMZHWG3f4Cpp/wPJPLPvY21TVK8OOV4GT8jH7eA7iwOpiVFr9F3GmJA6K5co+cfjgEMkXGqUh1vJnjJ0a9T0RXQrg/QBOYV6IxkYAYmGMqc42RGwfNETV5AaACS11vhjrupQRqMORdeLjRWzabevz9pJy9sscF+cfBnGwmD+jFf/72ib3M9eH+QLynAe4Qd3Rm0WuwHwOX07+LfVppExyJSJX9smkii4yF+bwFcl+IF/AQM5PLnLmtAqcpDJOffqwc4UhShpyHb5iVU9pH/4bUbUSFeq5z7gm3P6J+YF941CfNvHg507AjDGNmPeNh5X7GEShsfVxaKlPobM/L8k+9u9jMMhfTOTTqA6UG+2zEMDVAM5hjImFXO4FcCER1RHRLACzAbwA4EUAs4loFhFlYDuF7y2nD0kQFl7I0VSX8k2P69NmYC3TXIEFQuR4WOa45rpAqYQwzT9usRcAmD+z1fcdt/y7B7zQU1/EkZNoJjp8+Xma61PImIbbnz5B9gmz/MMKvYX1XST7AYXsI/sqVPCWFnQG0SKrR0dZ/qpFUWQy5faBsryDIr6dR9TUpUzsI/iLiuHoAya1BGq8y+c1ipSTOHgEiWjEcNmnVM1fY+9CuZr/zwHUAVjk/DiXMMY+zRh7k4juhu3IzQO4gjE7MZ6IrgTwMAATwB2MsTfL7EMknlrZjvHN6qqdHPXSFL4+bfrC7gBbw5dJkVtQ45rrXNnHiiF/7jc46T8fw0ULpuPfTtzX1xYAzJGqFnJ9ma/OxBiwS8hB2OFk5za5lr9Hns31abveTIFh5jX3u8c0Zkx0D6gHxTB/RegSlcK12g5fP9mHnUeEWNXTvoYiLf8Ecf7ifvL+3pqpqrad/4U8Af5M5EXoK5mpL0b7FIvPnzobFx89HeMF/Z/3WX4+ewqjGtOBd204QZ7t7+0oN9pnv4jvbgBwg2L7AwAeKOe8SdG2rRufuOOFUGssbRJyBRZYC7QuZSgsfyt00RPxR+Nq/iEEygeFtTt68f0Hl7vkz7ff9+/vCcQde5a/TaJ5i/nW4eV/8+soME/2OXjKSOzuy/msPSJ7wOvPhQxQIcQbKvvkoh2+XdLM6wfnH4xcgeFr/3zD3ebKPia3/Cun+XPKZiw8k9eKEP1VMfy8n/K9KiYqh+O1b56urHZqV/Ysujm7H0Q+4ge835FYdHBP4qWvnTYk502CJ798Mprqaov89+oM3/3Gj8B3zjnQterklYP4lFtO2KlLGYraOEGHL0dGqMDYmy3gij+9jDc2dij3LVhM6YPI5i1MbKnHQVOCK/dwouF+hILFfGUSdvXYf/NpPXf4zp/RiuvPOwhpk3zrEDekTV8kSaAvYZJV6OL0/jh/WfOXLf+0aQTKG3iWf2nkH6n5O18xFr6S1wGT7NnWvEnBWvF+2Yj7JoLJaOK5VPjUe2bh+vOCZbpHNqQxsjG4YlTUMyoFfFWq9zilKPY0xDLVww3TxzQWvaxktWOvr+p5wVHT8PV7bGVpv/EjfOUTGjMmOvpygaloygxa/vkCUzpxifwRFYvf2oq3NoeXX8
4VLCW55gosNNlEJqq8xXxrDbiWv9PngmVr/py0MqaBzj6PgN3qjiGZs2Hp/2GWvzjLGVBE+3RLlv/U1kZs3O2v9S9X9SxWlk4Sq83AvJBM6buFB03C4qvei/3GK8jfDfX07Ho+O5Nng1EyzdffX1w+oxjtUwmMqEvhqatPxnhFKRON2sNeT/5iSJ5cv4eToEz09nHBksgqyz9jGiDypuec+BvSptLRmbeYMs46m7dCyxgE1wm2XDLPmIa7IDWfwRQYg2V5pJU2DV9f3LruIUQVFqkk9mN3b9b2jaTN2Dh/0VENAEfOaA0sn5lyl1H0LP+ZYxqxpTN8mU0RSS3KqAJuKuIHoCyxkA7JR6hkyPi5h09RrtRVDuRywhq1i72e/AG7hvaLa3ZiW5e/AieXfRoU5C/Xaw/T/Lm1zotvcS5ozESQv0Jrt5O51Jqj7D8oCJb/pFH12OoQJO/zsg270ZPNuyUj5OUPm5zrDiPMcNnH+/ui257HsfuMwTc+MM/tX33aiNX8F8waDdOgQKSRN1B5pProF09S9kOFKPLng524Em4xjlRxJTCvv/7oLnd7Bdn/5LnjgbkVa05Dw4eaIP9j9hmDY/YZg2//rz+wiFvKsuYPBC3/bIEpI3jE/Uwi5B32DwtvzIfKPlao7CO3lbeYW3F0Qku9W++fk/8flrwLwCszLOvr8opOqj6qIMo+2zr78c7WLgBexmhzfdqO9smpo32+/6GD8WFnTV8ez88h1/O3hEXokyBKG//5R4/AXS+8iwMmNXuhk4lbVtfz5xnZnz1pX9++5VL/3f92bKCUuIbGYKAmyJ/jogXT8Ztn1rqfORFw0vzxRw5Fm7OAitLyF8ivKWOiJ1vwEavB1xtEOPlbLJipC9jWfSZE9pHb4pZ/fdrw1QSSZzByvRwOLneFcWuo7CNIHNm8hU0dtv+Ea/zN9Sm17ONY/iMb0r7FtEXIoZ7FxvlHaeNTRjXgi6fbJrQn+xQxsPAsW4PA6T1jGsrEpnJLBSStra+hUS726mgfGXMmNOP1b3lreoolEADg/COn4isL9weg0PzzftlnVKNd/iEtFePikINVvnjaHHz5DJuAxMgbt/0oy18aLHKO5t9Sn3YlHPs6ZPK3/w+Sv31MGFElSVAbyFvY0tEPxpjrr6hPmXaGb97CJ4+fhaVfs5PDueUv3h9Z9jGlgarYOP+k4MJPMRytKuwWhmEazKKhEUBNkT/gJ0huycdp/kR8+UOPFFubbItbtPyjyCGdMlwLUgx9XLJ6BwBbZw+rKx5m+bc0pH3O6oaM2pqW45fjklnCyJ/LPowxZAsWerMFdPblbX+FaSCTMrD47W3oHsijLm2495Br/uJAKZOkK/sYnsN3MOBa/kUINKriamG9q9UiYRrVh5ojf5FgOZnFRfukTQM5i/lkn1bH8hetdVl6EHkgbRqu5ds74JH5hbcuwdrtPY7sk8zy56GeLfUpXz/lYmNc9mmp98eQh2VZ/uCh5ZHJbFz2EWWdb9z7Bn799BrUpU3fvegdyLt94zWJxLr8wXslyz6DS/7FWOjuMUJ5iLDuactfo1pQc+QvgpO5LJcA8nqqhFzeb/mPUpC/7KAUZxQZk9yBp0eSfTZ39CNbhMPXTvLKo6Uh7eu7nKnskn+D37UTZpze8vgqPPD65ljLX3RYP/LmVgB2lI54z845bLJ7vU++Y684JkYdhXFkseUdzjtscqL9vPMWL/twiHH+oe1r8teoEtQ0+XOCkev4AP4BIZ0yXIvYi/awCVW01mVrVbS406Yn+/QM+Ml8d2+2SMvfciz/tC9SSR7E+Plkyz8K3QN5pcPXNEhZuoIPTGIfv3jaHBw5w3ZcniuQszjrCpNH+OzghNnJynjf9C+HYfX3zkq0L+CFlBYjz/DH6h/ci1upTUNjuEGTP4IWMxCUfbIFhrxl4dMn7ou1N57tSieitV6QSDMjtcFj2HukcgdbO/uLC/Us2OUdWhpSfstfdvhy8m/wk3+Uot
Kfs5SWv0nkZt2qktQ6+/Nu9VQxg1QsZyAmsYVxpGEQHv/SSfjVx44M76QAvuhJUpRt+cfKPpr8NaoDNRXqKaPAwmWfep9kYycv5QrM1e15DRrRWpdr36QMQn3aQH/Osp2dDufLtW62dA44ETPJxuJcwUJHXw6jGjK+WUsw1NP+ny/ykgT9uYKysJthCIvThxSt4wu5i1VUm+u8c4vXF0WSM2PWUy4HpTh8OZKMMZr6NaoFNU3+liv7xFn+hL6c47R0Sd9f4AsIZnumTEJrYwabO/qRMQmMOQ7frMryD6/tI2NXbw4Wc0rkioOUvK5siMOXhcaq2OSvInfb8o8jfzs5aVyzZ/mLMoho+R88ZSQuPno6Zo1tchfF2RMgReROUiRJOtOWv0a1oKbJ33P4Bkk3ZfolG05QY53Kf3wQEAlBdlKahoGRDWls7uhH2jRc2aRH0vC3dPTbmr9A3j/88CH4yaJ3sKkjSIy8fv+oxowvdl4mJ050suUfJfv0ZQvKGkaG4ZE/T+qS6xfxzNSwwmGi5W8ahBs+eHB4RwYJ/A6V7vC1DwwP9SypWxoaexw1qfn/z6ePxV8uP8az/BWyj4i0abjVQHlxOD44iJatTP5px/IHbPJMRWj+cpz/BfOn4W+fPU7ZHz54tDamA8loInhz4kB2wuyx+Nyps0OP6c8XQh2+suwzY4y/SFiT4zsZ06Qm/1RCWWswYZQg+/CZkkHx5K7JX6NaUJOW/1Ez7UgUrtHLhc9kpFMG2p2icBOcBTK4hBF1pGmQmwzW1Z/HCCfZSiZ/3rZM5GGLhHOMakz7SjWrzi/jzk8dHdlmX9Z2+F5y7Axc+b7ZOOqGxXZbJFr+NvnPGtuE5Vu63GP/ecXxeHX97lB5JD0MguArJftoh69GtaMmyZ9jzoRmbO0cUGb4ihBr7ox39Gyvmmf4cSmD3BK6BcuCadgDgRzq6SZBSYOQOBiIFUM5RjVmlJE33jHFE1F/roBcwUJd2nSLlwH2zEW2/E/efzw6+nJYeNBEzBzThNkTmjF7QrAs8oi6FLoH8kUVahssuD0o4d4QEY7ZZwwefGMLZo5Vl0bW3K9RLahp8v/5R4/Amxs73IStMPA6NAbBXe2Hb4t611OGgS+cOgejGzP4wCGT8cKanQCCSV4ccpy/6ANorksFFqIf1ZCOrABZyipQ/bkC8gWGtEl+f4LC8j9o8khcMH9abJtfPmMuvnnvm+4aw0OJUjJ8OUyDcPHR03HqvAmBVeE4tOWvUS0Y+rdxCDGyIY3j9gtf0u6U/cdjn3FNWLHVrvQ5rrlOKD3M5YPwlz1lEurTprtOL9e8wxY0l+v5i+TdXJ8OkP/IhugFsUWCu/S4mYG1dFXoHsgjbzGkDMONnz9p7jis3NrtOsi5wzdpdNIlx83EJcfNTLTvnkIpoZ4m2cs4hhG/3a6GRnWgLA8cEV1PRMuI6FUieoSIJjvbiYhuJqI25/sjhGMuIaKVzr9Lyr2AwcSvLz0K1509z5V9JggLYhtewHgo5KUFeUG1XT1Z1e4BMhVrycgRO831KaRMQxmpJB7P8a1zDsSPLzg0vLMO+EyC9+XZa96HX33sSDvOXwr1jHI2D3cUY6BHrO2uaFfTv0Z1oNy39z8ZY4cwxg4DcB+AbzjbzwQw2/l3OYBbAICIRgP4JoCjASwA8E0iai2zD4MOHoUjkn9U5RlOnKZUtphr6Lt6PammScgunjyyHjL4ACKTP48iiopUKkX24Q7cMU12+5NHNaA+bdqyDwN29mTd2j7VTP6DF+dfQmc0NIYAZb29jDFxpfImeJx4LoDfMxtLAIwiokkAzgCwiDG2kzG2C8AiAAvL6cOeAJdrprZ6033GePhf8G3nFr7swJXLLACeDwEApo8JOhFNl/z9x452yDmKgEtd/LspY+KsQyYF2lq5tQtHXL8Iv392nXPuaEf5cEYpFnoS8teWv0a1oGzNn4huAPAJAB0ATnY2TwGwXthtg7MtbLuq3cthzx
owffr0crtZFnodjX5qq0fOfJRTveoNaRO7kQuQRXNdKhC1M2ZEBu/utJdhnDQyqCXbjmUrUH+I685Rln8pzsf9xo/AhUdNC2QFm0RYvb0HALDCWb4xqeY/nFCMhCMjCbEPpuU/Y0wjZocsMq+hUSxiyZ+IFgOYqPjqOsbYPYyx6wBcR0TXArgStqxTNhhjtwK4FQDmz58/OMXdE2Jrl51l63P0RZAIDx2Va/UYBgWidkYLkUYqy5JvM4nwm0uPwv8u24S/v7wRk0fZElGk7FMkN2dMA4uvOjGkLQqUdahG8uco1eE7GO0mxRNfPjl+Jw2NhIh9exljpzLGDlL8u0fa9Y8Aznf+3ghAjAGc6mwL2z6ssa3TTsISZR8e86562aMWSB/prLnLHbVxBMo1f4PsuHq+bCOfJURJEUVb/hG7Rw1M1QRvDd/kx3DLI1Fht+q7JRo1inKjfcQ6AecCWO78fS+ATzhRP8cA6GCMbQbwMIDTiajVcfSe7mwb1tjmZOBOaw1q8qqXPUzzB4DmOpv8eQkEwyBccuwM/PDDhyjP7S4e7pxoa6c9C5mocA7LKJb8o8itGoleBS77lHI5SXwomvw1qgXlav43EtFcABaAdQA+7Wx/AMBZANoA9AK4DAAYYzuJ6HoALzr7fYcxtrPMPgw6rjx5P/z8sTbfilhzJtra64JZowP7NzjWuYowuRU5ZkQGG3f3IWUQvn3uQYH9OFzyd/7f5paZUNfPEVEs+UdJFt6SkMFks2rEYMk+OslLo1pQFvkzxs4P2c4AXBHy3R0A7ijnvHsaXzpjLr50xlzftiOmt2LJtacoLXBeYz9lBCdWPEHq48fMwJLVO/GVhXMD+4gwBdkHsIu5AUHn8BHTRwEAjpzRipfW7XKOjWw6gCSW/6HTRuGplduLa3gYoqSqngnupyZ/jWpBTWf4losw6YU7euUkLwAYyNlO02mjG/GRBKUReBucfH98wWF4ce1OTBacz8uvX+h+/7fPHIdLf/MCHl/RXnSoZ1Q0C7d6J7bEy03VgKJCMl2pSMs+GnsPNPkPAjjpqkoY87o445vjZRvAI31OVqObMjjjQH/wlRzxw/mneNknHLxu/4wxjbj9E/PdEtfVilI4OlmcfwkNa2gMATT5DwLcpR4VZDFvcguefKcdYxOSv1hUrlgUm+Ebtfuqdru+0dyJLTh13oTiOzPMUJLDd4hDPTU0KglN/oMAzhEqS/FnFx2O5Zs7A0lUYRDj/JOCk1QlZZ9eZwGZ/SfuHUlGpWTiJiH/vSQoSqMGUL1ZOsMQ/7ziePz0wsNc608V6jmyIY2j9xmTuE2+0EwxZFVM2eLj9vX6kmT/qIqW1YTSMnzj99EOX41qgSb/CuKwaaNw7mFT3GX/5MJupcCUHL5JwAeKJMfc+amj8bOLDvcdp8Lp8yZgamtDyfWChhuKGUz580xyhOZ+jWqBln0GESrLv1hwuacYzuW7JiE40yC3OFzU3rd+Yn7yDlQBSnkySe6nLuymUS3Qlv9gwAkNrERWrJzhmwSuzyHhMZ60VFzfqhmDJc/U0j3UqG5o8h8E8CxeVahnseDEXIzcwn0OSU/vEWHtMFcpi7kkgdb8NaoFmvwHAbzWvyrUs1iYJYR6esXLkh0kZxHXAgbrUmvpHmpUNzT5DyIqQf6pMkI9kx7D96slo3XQZJ8amj1pVDc0+Q8CPNmnfCIwXGIu3uObNNjIKMGvUPUYpFDPWrqFGtUNTf6DiEpYgXzyUJTD1z22ONmnFnjLq82vHb4atQ1N/oOAYhyEcfCyhYs5pjgGKml2UeUo5kqLeZw1NXvSqGpo8h9EVIIH+OyhqAxf5/+kg5BXPK6YnlUnvByIwWlfk79GtUCT/yCgkgsOc92+uAxf3o9kPdEO38qhhm6hRpVDk/8ggFVQ9+EWfzGBQ5zYknaDDzC1ZLUWc6kHTm4BALQ2ZirarobGUEKXdxgEcM6thIZerPNWPMYqVvZJ3q2qRSnP5rqzD8
C5h03BfuNHxO5bS34TjeqGtvyHOdzyzEWVd+CWf3GyT01Z/kXsW5cyceSM1kHri4bGUECT/yCiElRaTHlm+Zik4lMtVXcYbIevhka1oCLkT0RfJCJGRGOdz0RENxNRGxEtI6IjhH0vIaKVzr9LKnH+YYcKenzdbN1SsoUT9oNPEGqBDwc7zl9Do1pQtuZPRNMAnA7gXWHzmQBmO/+OBnALgKOJaDSAbwKYD/s9fImI7mWM7Sq3H8MRlQn15G0VU97B/j9ptI9VxALlewtq50o1NNSohOV/E4Cr4bczzwXwe2ZjCYBRRDQJwBkAFjHGdjqEvwjAwgr0YVjh6H1GAwBmjmkqv7GSMnztfZM6fN3FSmqIEWvpWjU0VCjL8ieicwFsZIy9JlmmUwCsFz5vcLaFbVe1fTmAywFg+vTp5XRzj+Pjx8zAKQdMqMiSh57sk/wYV/MvUvapKcu/hq5VQ0OFWPInosUAJiq+ug7AV2FLPhUHY+xWALcCwPz58yuZNzXoIKKKrXVbzopTVkL2n9BSDwA473DlOLxXQlO/Rq0jlvwZY6eqthPRwQBmAeBW/1QALxPRAgAbAUwTdp/qbNsI4CRp++Ml9LtmUGzCFlC8pDG6KYPl1y90l3OsBQz2LGfa6L1joXuNvRclv+2MsdcZY+MZYzMZYzNhSzhHMMa2ALgXwCecqJ9jAHQwxjYDeBjA6UTUSkStsGcND5d/GXsvePZtUv0eAM45dDIA4Jh9xiQ+pj5t1pQUMpiX+uDnTsC9V7xn8E6goVEBDFaG7wMAzgLQBqAXwGUAwBjbSUTXA3jR2e87jLGdg9SHvQTFSTiATfprbzx7sDq0V6CSlVdlHDCpZfAa19CoECpG/o71z/9mAK4I2e8OAHdU6rx7O7ywTQ0NDY3KoXZE3iqFK08Mpqlag6ghhUtDQwlN/sMcPGZfU7+GhkYlocl/mIPLPlYxHl8NDQ2NGGjyH+bwYvaHuCMaGhp7FTT5D3MUW6FTQ0NDIwk0+Q9zuJq/dvhqaGhUEJr8hzlczV+Tf0WgB1ENDRua/Ic5ii3SpqGhoZEEmvyHOdwlGYe4H3sLaqmEhYZGFDT5D3Noy7+y0LKPhoYNTf7DHHPGNwMAZo5pHOKeaGho7E0YrMJuGhXCh46Ygn3Hj8Bh00YNdVc0NDT2ImjLf5iDiDTxDwK09K9R69Dkr1GT0NK/Rq1Dk7+GhoZGDUKTv4aGhkYNQpO/Rk1Bx/lraNjQ5K9RU9Bx/hoaNjT5a2hoaNQgNPlraGho1CDKIn8i+hYRbSSiV51/ZwnfXUtEbUS0gojOELYvdLa1EdE15ZxfQ6NUaOlfo9ZRiQzfmxhjPxI3ENE8ABcCOBDAZACLiWiO8/UvAJwGYAOAF4noXsbYWxXoh4ZGYmjpX6PWMVjlHc4FcBdjbADAGiJqA7DA+a6NMbYaAIjoLmdfTf4aGhoaexCV0PyvJKJlRHQHEbU626YAWC/s8//bO7sQK8owjv/+brqFRqUtIa3UKgshEZtsYfRB2JdZtAleCEFdBEUpFBGlCFEXXRT0CVFomfatWZEIUZZCV7mu5ceaaVZGiblWaHXTl08X82xO2zlHlz0zczzz/GA477wznPnt/8x5nHlnjvO991Xr/x+SbpPUJ6nvwIEDddAMgiPEsE9Qdo5a/CV9KKm/wtQDPAtMAbqAfcBj9RIzs8Vm1m1m3W1tbfV626DktJ7QAsCoqP5ByTnqsI+ZXXksbyRpCbDGZ/cCk1KL272PGv1BkDkPzz6XyW1jubQzDiiCcjPSu30mpmZnA/3eXg3MldQqqQPoBHqBjUCnpA5JY0guCq8eiUMQDIcJ41q5b+Y5tIyKI/+g3Iz0gu+jkrpInjK4B7gdwMy2S1pJciH3L2Cemf0NIGk+8D7QAiw1s+0jdAiCIAiGiY6Hn7t3d3dbX19f0RpBEATHFZI2mVl3pWXxC98gCIISEsU/CIKghETxD4IgKCFR/IMgCEpIFP8gCIISEs
U/CIKghBwXt3pKOgB8O4K3OB34sU469SS8hkd4DY9G9YLGdWs2r7PMrOLP2Y+L4j9SJPVVu9e1SMJreITX8GhUL2hctzJ5xbBPEARBCYniHwRBUELKUvwXFy1QhfAaHuE1PBrVCxrXrTRepRjzD4IgCP5LWY78gyAIghRR/IMgCEpIUxd/STMl7ZS0W9KCgl32SNomabOkPu8bL2mtpC/99bSjvU+dXJZKGpDUn+qr6KKEpz3DrZKm5ez1oKS9nttmSbNSyxa6105J12ToNUnSekmfS9ou6S7vLzSzGl6FZibpREm9kra410Pe3yFpg29/hT/QCX/o0wrv3yDp7Jy9lkn6JpVXl/fntu/79lokfSZpjc9nm5eZNeVE8rCYr4DJwBhgCzC1QJ89wOlD+h4FFnh7AfBITi6XAdOA/qO5ALOA9wAB04ENOXs9CNxbYd2p/pm2Ah3+Wbdk5DURmObtk4Fdvv1CM6vhVWhm/neP8/ZoYIPnsBKY6/3PAXd4+07gOW/PBVZklFc1r2XAnArr57bv+/buAV4D1vh8pnk185H/hcBuM/vazP4A3gB6CnYaSg+w3NvLgRvz2KiZfQz8fIwuPcBLlvAJcKr++/jOrL2q0QO8YWa/m9k3wG6SzzwLr31m9qm3fwV2AGdScGY1vKqRS2b+d//ms6N9MmAGsMr7h+Y1mOMq4ApJdX/OZg2vauS270tqB64Dnvd5kXFezVz8zwS+S81/T+0vRtYY8IGkTZJu874zzGyft38AzihGraZLI+Q430+7l6aGxgrx8lPs80mOGhsmsyFeUHBmPoSxGRgA1pKcZRw0s78qbPtfL19+CJiQh5eZDeb1sOf1hKTWoV4VnOvNk8B9wGGfn0DGeTVz8W80LjGzacC1wDxJl6UXWnIO1xD33TaSC/AsMAXoAvYBjxUlImkc8BZwt5n9kl5WZGYVvArPzMz+NrMuoJ3k7OKcvB0qMdRL0rnAQhK/C4DxwP15Okm6Hhgws015breZi/9eYFJqvt37CsHM9vrrAPAOyRdi/+BppL8OFOVXw6XQHM1sv39hDwNLODJMkauXpNEkBfZVM3vbuwvPrJJXo2TmLgeB9cBFJMMmJ1TY9r9evvwU4KecvGb68JmZ2e87eR+UAAABeElEQVTAi+Sf18XADZL2kAxPzwCeIuO8mrn4bwQ6/Yr5GJILI6uLEJE0VtLJg23gaqDffW7x1W4B3i3Cz6nmshq42e98mA4cSg11ZM6QMdbZJLkNes31Ox86gE6gNyMHAS8AO8zs8dSiQjOr5lV0ZpLaJJ3q7ZOAq0iuR6wH5vhqQ/MazHEOsM7PpPLw+iL1D7hIxtXTeWX+OZrZQjNrN7OzSerUOjO7iazzqufV6kabSK7W7yIZb1xUoMdkkrsstgDbB11Ixuk+Ar4EPgTG5+TzOslwwJ8kY4m3VnMhudPhGc9wG9Cds9fLvt2tvtNPTK2/yL12Atdm6HUJyZDOVmCzT7OKzqyGV6GZAecBn/n2+4EHUt+DXpILzW8Crd5/os/v9uWTc/Za53n1A69w5I6g3Pb9lOPlHLnbJ9O84r93CIIgKCHNPOwTBEEQVCGKfxAEQQmJ4h8EQVBCovgHQRCUkCj+QRAEJSSKfxAEQQmJ4h8EQVBC/gFtzJkqOWV8fAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mV5jj4dThz0Y" + }, + "source": [ + "另外,`avg_final_reward` 代表的是多個回合的平均 final rewards,而 final reward 即是 agent 在單一回合中拿到的最後一個 reward。\n", + "如果同學們還記得環境給予登月小艇 reward 的方式,便會知道,不論**回合的最後**小艇是不幸墜毀、飛出畫面、或是靜止在地面上,都會受到額外地獎勵或處罰。\n", + "也因此,final reward 可被用來觀察 agent 的「著地」是否順利等資訊。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "txDZ5vlGWz5w", + "outputId": "bc284774-255a-45ac-dabf-3dfb5e1e5565" + }, + "source": [ + "plt.plot(avg_final_rewards)\n", + "plt.title(\"Final Rewards\")\n", + "plt.show()\n" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gyT7tNwkVdS-" + }, + "source": [ + "訓練時間\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_t-JsKxUViFy", + "outputId": "333aa287-0455-4028-b91c-f83c8d2e1b57" + }, + "source": [ + "print(f\"total time is {end-start} sec\")" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "total time is 674.2419369220734 sec\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u2HaGRVEYGQS" + }, + "source": [ + "## 測試" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "5yFuUKKRYH73", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 500 + }, + "outputId": "7901d4d3-a71b-468e-a12e-6bd9edff551e" + }, + "source": [ + "fix(env, seed)\n", + "agent.network.eval() # 測試前先將 network 切換為 evaluation 模式\n", + "NUM_OF_TEST = 5 # Do not revise it !!!!!\n", + "test_total_reward = []\n", + "action_list = []\n", + "for i in range(NUM_OF_TEST):\n", + " actions = []\n", + " state = env.reset()\n", + "\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + " while not done:\n", + " action, _ = agent.sample(state)\n", + " actions.append(action)\n", + " state, reward, done, _ = env.step(action)\n", + "\n", + " total_reward += reward\n", + "\n", + " #img.set_data(env.render(mode='rgb_array'))\n", + " #display.display(plt.gcf())\n", + " #display.clear_output(wait=True)\n", + " print(total_reward)\n", + " test_total_reward.append(total_reward)\n", + "\n", + " action_list.append(actions) #儲存你測試的結果\n", + " print(\"length of actions is \", len(actions))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/__init__.py:422: UserWarning: 
torch.set_deterministic is deprecated and will be removed in a future release. Please use torch.use_deterministic_algorithms instead\n", + " \"torch.set_deterministic is deprecated and will be removed in a future \"\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "260.62265430635034\n", + "length of actions is 299\n", + "-212.89375915819693\n", + "length of actions is 319\n", + "11.862808485612831\n", + "length of actions is 241\n", + "8.015383611389638\n", + "length of actions is 231\n", + "-219.21903722619058\n", + "length of actions is 256\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Aex7mcKr0J01", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1706a79-1fbd-4d61-bdcd-ab257cb152e5" + }, + "source": [ + "print(f\"Your final reward is : %.2f\"%np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Your final reward is : -30.32\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "leyebGYRpqsF" + }, + "source": [ + "Action list 的長相" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hGAH4YWDpp4u", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c7f5fa21-7b7a-43a8-8478-df76dce7a4ad" + }, + "source": [ + "print(\"Action list looks like \", action_list)\n", + "print(\"Action list's shape looks like \", np.shape(action_list))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Action list looks like [[1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 3, 2, 2, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 2, 3, 2, 3, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 3, 2, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2, 2, 3, 
2, 2, 3, 2, 0, 2, 2, 3, 2, 0, 3, 2, 2, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 0, 1, 2, 2, 2, 0, 1, 2, 2, 1, 2, 1, 2, 1, 1, 2, 2, 0, 2, 1, 2, 2, 1, 1, 1, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 0, 1, 0, 2, 2, 2, 2, 3, 3, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 3, 2, 2, 0, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2], [2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 0, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 2, 3, 3, 3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 2, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2], [1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 0, 2, 2, 3, 3, 2, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 3, 3, 2, 3, 2, 2, 3, 3, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, 
1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 2, 2, 3, 3, 3, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2], [1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 2, 0, 1, 2, 2, 0, 2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 3, 2, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 3, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]\n", + "Action list's shape looks like (5,)\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " return array(a, dtype, copy=False, order=order)\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l7sokqEUtrFY" + }, + "source": [ + "Action 的分布\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WHdAItjj1nxw", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5129773b-1f4a-4085-d2bf-3bc2abc2598c" + }, + "source": [ + "distribution = {}\n", + "for actions in action_list:\n", + " for action in actions:\n", + " if action not in distribution.keys():\n", + " distribution[action] = 1\n", + " else:\n", + " distribution[action] += 1\n", + "print(distribution)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "{1: 278, 2: 698, 3: 297, 0: 73}\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ricE0schY75M" + }, + "source": [ + "儲存 Model Testing的結果\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GZsMkGmIY42b", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8c55c932-4654-4f8c-f6b0-fa52ac3e8b96" + }, + "source": [ + "PATH = \"Action_List_test.npy\" # 可以改成你想取的名字或路徑\n", + "np.save(PATH ,np.array(action_list)) " + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " \n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "asK7WfbkaLjt" + }, + "source": [ + "### 你要交到JudgeBoi的檔案94這個\n", + "儲存結果到本地端 (就是你的電腦裡拉 = = )\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "c-CqyhHzaWAL", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "adfba5e6-a107-49aa-9f98-3c0655c5d6c2" + }, + "source": [ + "from google.colab import files\n", + "files.download(PATH)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " 
div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "download(\"download_5d13b99b-295d-4ab0-814c-b2d0fff26eff\", \"Action_List_test.npy\", 2999)" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "seT4NUmWmAZ1" + }, + "source": [ + "# Server 測試\n", + "到時候下面會是我們Server上測試的環境,可以給大家看一下自己的表現如何" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "U69c-YTxaw6b", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 412 + }, + "outputId": "50015892-29ae-4665-c66f-880aecf7be8f" + }, + "source": [ + "action_list = np.load(PATH,allow_pickle=True) #到時候你上傳的檔案\n", + "seed = 543 #到時候測試的seed 請不要更改\n", + "fix(env, seed)\n", + "\n", + "agent.network.eval() # 測試前先將 network 切換為 evaluation 模式\n", + "\n", + "test_total_reward = []\n", + "for actions in action_list:\n", + " state = env.reset()\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + " # while not done:\n", + " done_count = 0\n", + " for action in actions:\n", + " # action, _ = agent1.sample(state)\n", + " state, reward, done, _ = env.step(action)\n", + " done_count += 1\n", + " total_reward += reward\n", + " if done:\n", + " \n", + " break\n", + " # img.set_data(env.render(mode='rgb_array'))\n", + " # display.display(plt.gcf())\n", + " # display.clear_output(wait=True)\n", + " print(f\"Your reward is : %.2f\"%total_reward)\n", + " test_total_reward.append(total_reward)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/__init__.py:422: UserWarning: torch.set_deterministic is deprecated and will be removed in a future release. 
Please use torch.use_deterministic_algorithms instead\n", + " \"torch.set_deterministic is deprecated and will be removed in a future \"\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "Your reward is : 260.62\n", + "Your reward is : -212.89\n", + "Your reward is : 11.86\n", + "Your reward is : 8.02\n", + "Your reward is : -219.22\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAYXUlEQVR4nO3de5BV5Z3u8e9D0zS3RrobbNvuVkR7NFwcYBDxdgZRGSTWwakQhXNGjUNJrGgl1EzlROdUHZ06yVSlMmNmUpnCIeV1ktHxmIuU4wxj1FQm4y2oqAghoiDQIhi5yGVAuvt3/tir2y1N09fN7rf7+VTt2mu9a629fm+7fXrx9rv3UkRgZmbpGFLsAszMrHsc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiSlYcEuaL2mjpE2S7ijUeczMBhsVYh63pBLgt8BVwHbg18CSiFjf5yczMxtkCnXFPQvYFBHvRsQnwKPAwgKdy8xsUBlaoNetBbblrW8HLuxoZ0n++Kb1mZKSUsaUn8bwoWM50vQx+w/uIqKF8tHVDB86ttevf7hpL/sP7KS5+Sjl5dWMHFbF0eaDHDi0iyNHDvZBD8xyIkLHay9UcHdK0jJgWbHObwPX9OmLmH3eLQwrGc2adx/g+RceoLLyTC6/ZDnnjVuIdNz/F7okIvjN757guf/8W3btepv6+mlcNPXLVAyfwGtbf8ivnv8HDh/e34e9MWuvUEMljUB93npd1tYmIlZGxMyImFmgGmwQqqw8g/qaP2BMWT079r/O5i0v0tR0uEBnCzZvfpEd+9YydMgIaqumU1MzuUDnMvtUoYL710CDpLMkDQMWA6sKdC6zNhMnXkztKTM43LSPxt+9SmPjmwU936FDe3l3y/N8eGg9p46aytkTL2XYsJEFPadZQYZKIqJJ0u3AaqAEuD8i3irEucxaVVVNoP60mYwpq2Pz7l/wzjv/SUtLEwDNzUc59MluNu99ttfnOfTJbpqbj2Zrwdatr3DWGWsZV3sep1dOp67u93n33Rd6fR6zjhRsjDsingKeKtTrm+WThjDhzFmcNmYKh5s+pvGj13j//XVt2z/6aDOvvPYIZWXlvT7XkSP7+eijzW3rBw9+xJZtL3J6xXROG30+EydexLZtr3H0aKGGaGywK9ofJ836Wks08+HBjUSsZ9Om/yCi5TPb33+/cP/o27btNd6ve5WoaGbfvg8YPXo8e/Zs6/xAsx4oyAdwul2EpwNaHygpKaWi4gxaWprYu7exbZjkZBk7tpaamsn89re/oLn5k5N6bhuYOpoO6OA2M+unOgpuf8mUmVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJaZXd8CRtAXYDzQDTRExU1Il8M/ABGALcF1E7OldmWZm1qovrrgvj4hpETEzW78DeCYiGoBnsnUzM+sjhRgqWQg8lC0/BFxbgHOYmQ1av
Q3uAP5d0iuSlmVt1RGxI1v+AKju5TnMzCxPb+/yfmlENEo6FXha0m/yN0ZEdHQ/ySzolx1vm5mZdazPbhYs6W7gAHALMCcidkiqAX4REed2cqxvFmxmdow+v1mwpFGSyluXgXnAOmAVcFO2203AEz09h5mZtdfjK25JE4GfZqtDgX+KiG9JqgIeA84A3iM3HXB3J6/lK24zs2N0dMXdZ0MlveHgNjNrr8+HSszMrDgc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klptPglnS/pF2S1uW1VUp6WtLb2XNF1i5J35O0SdIbkmYUsngzs8GoK1fcDwLzj2m7A3gmIhqAZ7J1gKuBhuyxDFjRN2WamVmrToM7In4J7D6meSHwULb8EHBtXvvDkfMiMFZSTV8Va2ZmPR/jro6IHdnyB0B1tlwLbMvbb3vW1o6kZZLWSFrTwxrMzAalob19gYgISdGD41YCKwF6cryZ2WDV0yvuna1DINnzrqy9EajP268uazMzsz7S0+BeBdyULd8EPJHXfmM2u2Q2sC9vSMXMzPqAIk48SiHpEWAOMA7YCdwF/Ax4DDgDeA+4LiJ2SxLwfXKzUA4BN0dEp2PYHioxM2svInS89k6D+2RwcJuZtddRcPuTk2ZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klptPglnS/pF2S1uW13S2pUdLa7LEgb9udkjZJ2ijpjwpVuJnZYNWVmwX/N+AA8HBETMna7gYORMRfH7PvJOARYBZwOvBz4PciormTc/iek2Zmx+jxPScj4pfA7i6eZyHwaEQciYjNwCZyIW5mZn2kN2Pct0t6IxtKqcjaaoFteftsz9rakbRM0hpJa3pRg5nZoNPT4F4BnA1MA3YAf9PdF4iIlRExMyJm9rAGM7NBqUfBHRE7I6I5IlqAH/DpcEgjUJ+3a13WZmZmfaRHwS2pJm/1j4HWGSergMWSyiSdBTQAL/euRDMzyze0sx0kPQLMAcZJ2g7cBcyRNA0IYAvwZYCIeEvSY8B6oAm4rbMZJWZm1j2dTgc8KUV4OqCZWTs9ng5oZmb9i4PbzCwxDm4zs8Q4uM3MEuPgNjNLjIPbzCwxDm4zs8Q4uM3MEuPgNjNLjIPbzCwxDm4zs8Q4uM3MEuPgNjNLjIPbzCwxDm4zs8Q4uM3MEuPgNjNLjIPbzCwxnQa3pHpJz0laL+ktSV/L2islPS3p7ey5ImuXpO9J2iTpDUkzCt0JM7PBpCtX3E3An0fEJGA2cJukScAdwDMR0QA8k60DXE3u7u4NwDJgRZ9XbWY2iHUa3BGxIyJezZb3AxuAWmAh8FC220PAtdnyQuDhyHkRGCupps8rNzMbpLo1xi1pAjAdeAmojogd2aYPgOpsuRbYlnfY9qzt2NdaJmmNpDXdrNnMbFDrcnBLGg38GFgeER/nb4uIAKI7J46IlRExMyJmduc4M7PBrkvBLamUXGj/KCJ+kjXvbB0CyZ53Ze2NQH3e4XVZm5mZ9YGuzCoRcB+wISLuydu0CrgpW74JeCKv/cZsdslsYF/ekIqZmfWScqMcJ9hBuhT4D+BNoCVr/gty49yPAWcA7wHXRcTuLOi/D8wHDgE3R8QJx7EldWuYxcxsMIgIHa+90+A+GRzcZmbtdRTc/uSkmVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZonpys2C6yU9J2m9pLckfS1rv1tSo6S12WNB3jF3StokaaOkP
ypkB8zMBpuu3Cy4BqiJiFcllQOvANcC1wEHIuKvj9l/EvAIMAs4Hfg58HsR0XyCc/iek2Zmx+jxPScjYkdEvJot7wc2ALUnOGQh8GhEHImIzcAmciFuZmZ9oFtj3JImANOBl7Km2yW9Iel+SRVZWy2wLe+w7Zw46M0A+Ku/+jLf/jZMmQKTJsHppxe7opNvzpw5PPjguSxYAJMnw3nnQUlJsauy/mZoV3eUNBr4MbA8Ij6WtAL4v0Bkz38D/Gk3Xm8ZsKx75dpANnXqRGpqYO7c3PqOHbB+fW753/4NNm2CCPjgA2jucOAtbePHj2fWrANMnpxbb2qC55+Ho0dh+3b42c9y7fv2wf79xavTiqtLwS2plFxo/ygifgIQETvztv8AeDJbbQTq8w6vy9o+IyJWAiuz4z3GbW2UjeqdfvqnV92XX54L7eZmWL0a/uu/csH+wx8Wr85Cav0ZlJbCH/5hbjkC/uRPcsvr1sHGjbnlhx+GnTvbv4YNXF2ZVSLgPmBDRNyT116Tt9sfA+uy5VXAYkllks4CGoCX+65kG4xaWnKh3dQEhw7BwYO58B5MWn9xNTfD4cO5n8HBg7mfjQ0uXbnivgS4AXhT0tqs7S+AJZKmkRsq2QJ8GSAi3pL0GLAeaAJuO9GMErN8EbkH5IYG1mbvuNWr4d13c9t27x74YdX6c2hqgmefhU8+gcZGWLUqt/3AgcH3i8s+1WlwR8SvgONNSXnqBMd8C/hWL+qyQejAAfiXf8kNf7S05MZwP/yw2FWdfGvXwg9+AO+9l/s5bN068H9RWfd0+Y+TZoW2dSvcfXexqyi+e+6BNWuKXYX1Zw5uM7M8o0ePpqamhvnz57N69WqO/ZDikSNH2Lp1a5Gqy3Fwm9mgV15ezpgxY/jqV7/KlClTmDdvHkOGDKHlOGNUu3fv5vHHH28X6Dt37uTee+9tt39TUxN79uzp03od3GY26JSUlDBs2DBqa2u5/vrrufrqqzn//PMZNWoUQ4Z8Otkuf7nVqaeeyle+8pV27c3NzXz9619v175r1y4eeOCB47b/8DjzWVtaWjhy5MgJ63dwm9mgIImRI0fyhS98gYsvvphrrrmGsrIyxo0b1yevX1JSQnl5ebv28vJyvvnNb7ZrP3r0KHfddVe79h07drBixQqeeOKJDs/l4DazAa26upqZM2fypS99ialTp9LQ0HDcK+mTrbS0lNra9t8GUltby3333cfrr7/e4bEObjMbcCZOnEhdXR3f+MY3mDBhApMmTSp2SX3KwW1mA8KECRNoaGjg1ltvZcaMGZx55plIx/1W1OQ5uM0sSSUlJVRVVTFnzhwuv/xyFi1aREVFBSWD4OsUHdxmlozW2SCLFy9m6tSpLF26lLKyMsrKyopd2knl4Dazfk0SI0aMYNGiRcyePZuFCxcybtw4hg0bVuzSisbBbWb9UnV1NRdccAE33ngj559/fr+ZDdIfOLjNrOgmT57MmDFjALjiiiuYN28elZWVTG69o4R9hoPbzE6a+vp6RowYAcDcuXOZN28eAJdddhlVVVVt+w3U2SB9xcFtZn2uoqKibQz6sssu44orrgDgmmuu4bTTTgNyHyf30EfPOLjNrMfKysrapt9deOGFXHbZZQAsXryY+vrcHQyHDRs2qP+QWAgObjPrktar46lTp3LBBRcAcPPNN3PWWWcBMGrUqLZxaissB7eZHdfZZ5/d9lHxuro6br/9dgCqqqqorq4uZmmDXqfBLWk48EugLNv/8Yi4K7sR8KNAFfAKcENEfCKpDHgY+APgI+D6iNhSoPrNrIcaGho+88145557LkuXLm1br62tbbuatv6lK1fcR4C5EXFAUinwK0n/CvwZ8N2IeFTSvcBSYEX2vCcizpG0GPg2cH2B6jezY1RUVHDqqad+pu3zn/982/hzq5kzZ7b7djrP5khDV24WHMCBbLU0ewQwF/gfWftDwN3kgnthtgzwOPB9SYpjbxdhZt2WP1uj1a233
kpNTU3b+pQpU5g9e/Zn9pHkGRwDSJfGuCWVkBsOOQf4e+AdYG9ENGW7bAdaf3XXAtsAIqJJ0j5ywym/68O6LVElJSUdfq/E888/z6hRo9rdEmow2bJlC/v27WPkyJEsWLCg3QdQFi1a1G74Yvjw4YPii5XsU10K7ohoBqZJGgv8FDivtyeWtAxY1tvXsd6R1KN/Hk+dOpULL7yw28edc8453HDDDd0+bjA65ZRT2j6sYpavW7NKImKvpOeAi4CxkoZmV911QGO2WyNQD2yXNBQ4hdwfKY99rZXASgBJg/cSq49Nnz6dM844o8v7X3fddcyYMaPb56msrGw3jmpmJ0dXZpWMB45moT0CuIrcHxyfAxaRm1lyE9B6g7RV2foL2fZnPb7dsTFjxjB16tQOt3/xi19smzPbFeecc44D1WyA68oVdw3wUDbOPQR4LCKelLQeeFTSN4HXgPuy/e8D/lHSJmA3sLgAdSfnzDPPZOTIkQBcddVVbR8Brqys5NJLLy1maWaWmK7MKnkDmH6c9neBWcdpPwx8sU+qS1BVVRWlpaWUl5ezfPlySktLgdx3NLReCfsv/GbWG/3ik5OSkppJUFJSwvDhw4Hcx3xvueWWtoBesmQJtbW1bV/+7nmxZtbX+kVwT5w4kSuvvPKE+zz11FO8//77HW6PCFpaWvq0riFDhrQF7yWXXMLnPvc5IPcJsyVLlgCf3vfOV9BmdrL0i+AeO3Ys99577wn32bZtG4cOHepw+8GDB/nOd77D4cOHO9xn586dvPDCCx1uHzVqFFdeeWXbFLnly5e3fSdDdXU1Y8eO7aQnZmaF1y+CuytavyLyRB555JETbt+zZw8bNmzocPuIESOYNm2ahzfMrF9LJrj7QkVFBRdffHGxyzAz6xUPzJqZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJabT4JY0XNLLkl6X9Jakv8zaH5S0WdLa7DEta5ek70naJOkNSd2/hbiZmXWoK1/regSYGxEHJJUCv5L0r9m2r0fE48fsfzXQkD0uBFZkz2Zm1gc6veKOnAPZamn2ONENIhcCD2fHvQiMlVTT+1LNzAy6OMYtqUTSWmAX8HREvJRt+lY2HPJdSWVZWy2wLe/w7VmbmZn1gS4Fd0Q0R8Q0oA6YJWkKcCdwHnABUAl8ozsnlrRM0hpJaz788MNulm1mNnh1a1ZJROwFngPmR8SObDjkCPAAMCvbrRHIv0FkXdZ27GutjIiZETFz/PjxPavezGwQ6sqskvGSxmbLI4CrgN+0jlsrd2fda4F12SGrgBuz2SWzgX0RsaMg1ZuZDUJdmVVSAzwkqYRc0D8WEU9KelbSeEDAWuDWbP+ngAXAJuAQcHPfl21mNnh1GtwR8QYw/TjtczvYP4Dbel+amZkdjz85aWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliHNxmZolxcJuZJcbBbWaWGAe3mVliFBHFrgFJ+4GNxa6jQMYBvyt2EQUwUPsFA7dv7ldazoyI8cfbMPRkV9KBjRExs9hFFIKkNQOxbwO1XzBw++Z+DRweKjEzS4yD28wsMf0luFcWu4ACGqh9G6j9goHbN/drgOgXf5w0M7Ou6y9X3GZm1kVFD25J8yVtlLRJ0h3Frqe7JN0vaZekdXltlZKelvR29lyRtUvS97K+viFpRvEqPzFJ9ZKek7Re0luSvpa1J903ScMlvSzp9axff5m1nyXppaz+f5Y0LGsvy9Y3ZdsnFLP+zkgqkfSapCez9YHSry2S3pS0VtKarC3p92JvFDW4JZUAfw9cDUwClkiaVMyaeuBBYP4xbXcAz0REA/BMtg65fjZkj2XAipNUY080AX8eE
ZOA2cBt2X+b1Pt2BJgbEb8PTAPmS5oNfBv4bkScA+wBlmb7LwX2ZO3fzfbrz74GbMhbHyj9Arg8IqblTf1L/b3YcxFRtAdwEbA6b/1O4M5i1tTDfkwA1uWtbwRqsuUacvPUAf4BWHK8/fr7A3gCuGog9Q0YCbwKXEjuAxxDs/a29yWwGrgoWx6a7adi195Bf+rIBdhc4ElAA6FfWY1bgHHHtA2Y92J3H8UeKqkFtuWtb8/aUlcdETuy5Q+A6mw5yf5m/4yeDrzEAOhbNpywFtgFPA28A+yNiKZsl/za2/qVbd8HVJ3cirvsb4H/BbRk61UMjH4BBPDvkl6RtCxrS/692FP95ZOTA1ZEhKRkp+5IGg38GFgeER9LatuWat8iohmYJmks8FPgvCKX1GuSrgF2RcQrkuYUu54CuDQiGiWdCjwt6Tf5G1N9L/ZUsa+4G4H6vPW6rC11OyXVAGTPu7L2pPorqZRcaP8oIn6SNQ+IvgFExF7gOXJDCGMltV7I5Nfe1q9s+ynARye51K64BPjvkrYAj5IbLvk70u8XABHRmD3vIvfLdhYD6L3YXcUO7l8DDdlfvocBi4FVRa6pL6wCbsqWbyI3PtzafmP2V+/ZwL68f+r1K8pdWt8HbIiIe/I2Jd03SeOzK20kjSA3br+BXIAvynY7tl+t/V0EPBvZwGl/EhF3RkRdREwg9//RsxHxP0m8XwCSRkkqb10G5gHrSPy92CvFHmQHFgC/JTfO+L+LXU8P6n8E2AEcJTeWtpTcWOEzwNvAz4HKbF+Rm0XzDvAmMLPY9Z+gX5eSG1d8A1ibPRak3jfgfOC1rF/rgP+TtU8EXgY2Af8PKMvah2frm7LtE4vdhy70cQ7w5EDpV9aH17PHW605kfp7sTcPf3LSzCwxxR4qMTOzbnJwm5klxsFtZpYYB7eZWWIc3GZmiXFwm5klxsFtZpYYB7eZWWL+P8eEF+CNkxTNAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjFBWwQP1hVe" + }, + "source": [ + "# 你的成績" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GpJpZz3Wbm0X", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1b08157-bec6-4c5a-8021-482f719b4ade" + }, + "source": [ + "print(f\"Your final reward is : %.2f\"%np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Your final reward is : -30.32\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wUBtYXG2eaqf" + }, + "source": [ + "## 參考資料\n", + "\n", + "以下是一些有用的參考資料。\n", + "建議同學們實做前,可以先參考第一則連結的上課影片。\n", + "在影片的最後有提到兩個有用的 Tips,這對於本次作業的實做非常有幫助。\n", + "\n", + "- [DRL Lecture 1: Policy Gradient (Review)](https://youtu.be/z95ZYgPgXOY)\n", + "- [ML Lecture 23-3: Reinforcement Learning (including Q-learning) start at 30:00](https://youtu.be/2-JNBzCq77c?t=1800)\n", + "- [Lecture 7: Policy Gradient, David Silver](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/pg.pdf)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cGqP2EU1joWM" + }, + "source": [ + "" + ] + } + ] +} \ No newline at end of file diff --git a/11 Quantum ML/作业HW12/hw12_reinforcement_learning_english_version.ipynb b/11 Quantum ML/作业HW12/hw12_reinforcement_learning_english_version.ipynb new file mode 100644 index 0000000..5694e58 --- /dev/null +++ b/11 Quantum ML/作业HW12/hw12_reinforcement_learning_english_version.ipynb @@ -0,0 +1,3092 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hw12_reinforcement_learning_english_version.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "de3a153737af485ea436d7e8393d8248": { + 
"model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_6345f8926212465291c04587353161f1", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_3aa84c8c097d4858a0c38f18dec8b060", + "IPY_MODEL_6647e68cf064416ca593b990bed81edf" + ] + } + }, + "6345f8926212465291c04587353161f1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "3aa84c8c097d4858a0c38f18dec8b060": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_a25de7edbbee47cc8094f43125efa39b", + 
"_dom_classes": [], + "description": "Total: 86.3, Final: 0.0: 100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 400, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 400, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_a0bb5061ca0946ab89a89f9abadcebc8" + } + }, + "6647e68cf064416ca593b990bed81edf": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_d99d8e6bcbe0445eb7eddbfe31277635", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 400/400 [11:36<00:00, 1.74s/it]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_5af1c6579f0f4c2eb03df4063749569e" + } + }, + "a25de7edbbee47cc8094f43125efa39b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "a0bb5061ca0946ab89a89f9abadcebc8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": 
null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "d99d8e6bcbe0445eb7eddbfe31277635": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "5af1c6579f0f4c2eb03df4063749569e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": 
null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Fp30SB4bxeQb" + }, + "source": [ + "# **Homework 12 - Reinforcement Learning**\n", + "\n", + "If you have any problem, e-mail us at ntu-ml-2021spring-ta@googlegroups.com\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yXsnCWPtWSNk" + }, + "source": [ + "## Preliminary work\n", + "\n", + "First, we need to install all necessary packages.\n", + "One of them, gym, built by OpenAI, is a toolkit for developing Reinforcement Learning algorithms. Other packages are for visualization in Colab." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5e2bScpnkVbv", + "outputId": "52198e39-e2a2-4ea2-a1f3-4ba9545476d7" + }, + "source": [ + "!apt update\n", + "!apt install python-opengl xvfb -y\n", + "!pip install gym[box2d]==0.18.3 pyvirtualdisplay tqdm numpy==1.19.5 torch==1.8.1" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Hit:1 http://security.ubuntu.com/ubuntu bionic-security InRelease\n", + "Ign:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 InRelease\n", + "Hit:3 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease\n", + "Ign:4 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 InRelease\n", + "Hit:5 http://archive.ubuntu.com/ubuntu bionic InRelease\n", + "Hit:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Release\n", + "Hit:7 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease\n", + "Hit:8 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 Release\n", + "Hit:9 http://archive.ubuntu.com/ubuntu bionic-updates InRelease\n", + "Hit:10 http://archive.ubuntu.com/ubuntu bionic-backports InRelease\n", + "Hit:11 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease\n", + "Hit:12 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease\n", + "Hit:13 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu bionic InRelease\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "86 packages can be upgraded. Run 'apt list --upgradable' to see them.\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... 
Done\n", + "python-opengl is already the newest version (3.1.0+dfsg-1).\n", + "xvfb is already the newest version (2:1.19.6-1ubuntu4.9).\n", + "The following package was automatically installed and is no longer required:\n", + " libnvidia-common-460\n", + "Use 'apt autoremove' to remove it.\n", + "0 upgraded, 0 newly installed, 0 to remove and 86 not upgraded.\n", + "Requirement already satisfied: gym[box2d] in /usr/local/lib/python3.7/dist-packages (0.17.3)\n", + "Requirement already satisfied: pyvirtualdisplay in /usr/local/lib/python3.7/dist-packages (2.1)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n", + "Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.19.5)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.4.1)\n", + "Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.5.0)\n", + "Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.3.0)\n", + "Requirement already satisfied: box2d-py~=2.3.5; extra == \"box2d\" in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (2.3.8)\n", + "Requirement already satisfied: EasyProcess in /usr/local/lib/python3.7/dist-packages (from pyvirtualdisplay) (0.3)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from pyglet<=1.5.0,>=1.4.0->gym[box2d]) (0.16.0)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M_-i3cdoYsks" + }, + "source": [ + "\n", + "Next, set up the virtual display and import all necessary packages." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "nl2nREINDLiw" + }, + "source": [ + "%%capture\n", + "from pyvirtualdisplay import Display\n", + "virtual_display = Display(visible=0, size=(1400, 900))\n", + "virtual_display.start()\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from IPython import display\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.distributions import Categorical\n", + "from tqdm.notebook import tqdm" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CaEJ8BUCpN9P" + }, + "source": [ + "# Warning! Do not change the random seed!!!\n", + "# Otherwise your submission on JudgeBoi will not reproduce your result!!!\n", + "Make your HW result reproducible.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fV9i8i2YkRbO" + }, + "source": [ + "seed = 543 # Do not change this\n", + "def fix(env, seed):\n", + " env.seed(seed)\n", + " env.action_space.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " np.random.seed(seed)\n", + " random.seed(seed)\n", + " torch.set_deterministic(True)\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.backends.cudnn.deterministic = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "He0XDx6bzjgC" + }, + "source": [ + "Lastly, call gym and build a [Lunar Lander](https://gym.openai.com/envs/LunarLander-v2/) environment." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "N_4-xJcbBt09" + }, + "source": [ + "%%capture\n", + "import gym\n", + "import random\n", + "env = gym.make('LunarLander-v2')\n", + "fix(env, seed) # fix the environment Do not revise this !!!" 
+ ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NrkVvTrvWZ5H" + }, + "source": [ + "## What is Lunar Lander?\n", + "\n", + "“LunarLander-v2” simulates the situation in which the craft lands on the surface of the moon.\n", + "\n", + "The task is to enable the craft to land \"safely\" at the pad between the two yellow flags.\n", + "> Landing pad is always at coordinates (0,0).\n", + "> Coordinates are the first two numbers in state vector.\n", + "\n", + "![](https://gym.openai.com/assets/docs/aeloop-138c89d44114492fd02822303e6b4b07213010bb14ca5856d2d49d6b62d88e53.svg)\n", + "\n", + "\"LunarLander-v2\" actually includes \"Agent\" and \"Environment\". \n", + "\n", + "In this homework, we will utilize the function `step()` to control the action of the \"Agent\". \n", + "\n", + "Then `step()` will return the observation/state and reward given by the \"Environment\"." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rsXZra3N9R5T", + "outputId": "a36868de-bbbc-4de9-815b-0b43fc012c96" + }, + "source": [ + "print(env.observation_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Box(-inf, inf, (8,), float32)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ezdfoThbAQ49" + }, + "source": [ + "\n", + "`Box(8,)` means that an observation is an 8-dim vector.\n", + "### Action\n", + "\n", + "The action space looks like this:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "p1k4dIrBAaKi", + "outputId": "80c453ee-539f-4e40-c5d8-8e9dc8fffaef" + }, + "source": [ + "print(env.action_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Discrete(4)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dejXT6PHBrPn" + }, + "source": [ + "`Discrete(4)` implies that there are four kinds of actions that can be taken by the agent.\n", + "- 0 implies the agent will not take any action\n", + "- 2 implies the agent will fire the main engine (thrust directed downward)\n", + "- 1, 3 imply the agent will fire the left and right orientation engines, respectively\n", + "\n", + "Next, we will try to make the agent interact with the environment. \n", + "Before taking any actions, we recommend calling the `reset()` function to reset the environment. Also, this function will return the initial state of the environment." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pi4OmrmZgnWA", + "outputId": "2635bdfb-a4dc-442b-a21a-f57af67edc4b" + }, + "source": [ + "initial_state = env.reset()\n", + "print(initial_state)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.00396109 1.4083536 0.40119505 -0.11407257 -0.00458307 -0.09087662\n", + " 0. 0. ]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uBx0mEqqgxJ9" + }, + "source": [ + "Then, we try to get a random action from the agent's action space." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxkOEXRKgizt", + "outputId": "de93c740-f01c-464e-f436-a2b59e7dc7e5" + }, + "source": [ + "random_action = env.action_space.sample()\n", + "print(random_action)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mns-bO01g0-J" + }, + "source": [ + "More, we can utilize `step()` to make agent act according to the randomly-selected `random_action`.\n", + "The `step()` function will return four values:\n", + "- observation / state\n", + "- reward\n", + "- done (True/ False)\n", + "- Other information" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "E_WViSxGgIk9" + }, + "source": [ + "observation, reward, done, info = env.step(random_action)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yK7r126kuCNp", + "outputId": "2f3363d9-5bc3-4ba5-86f2-1c4abc89f179" + }, + "source": [ + "print(done)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "False\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", 
+ "metadata": { + "id": "GKdS8vOihxhc" + }, + "source": [ + "### Reward\n", + "\n", + "\n", + "> Landing pad is always at coordinates (0,0). Coordinates are the first two numbers in state vector. Reward for moving from the top of the screen to landing pad and zero speed is about 100..140 points. If lander moves away from landing pad it loses reward back. Episode finishes if the lander crashes or comes to rest, receiving additional -100 or +100 points. Each leg ground contact is +10. Firing main engine is -0.3 points each frame. Solved is 200 points. " + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxQNs77hi0_7", + "outputId": "4633d678-be4f-4f52-8f91-1b6681642580" + }, + "source": [ + "print(reward)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-0.8588900517154912\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mhqp6D-XgHpe" + }, + "source": [ + "### Random Agent\n", + "In the end, before we start training, we can see whether a random agent can successfully land the moon or not." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 269 + }, + "id": "Y3G0bxoccelv", + "outputId": "11ad28c1-058b-4243-bf35-1fdfbb60be9e" + }, + "source": [ + "env.reset()\n", + "\n", + "img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + "done = False\n", + "while not done:\n", + " action = env.action_space.sample()\n", + " observation, reward, done, _ = env.step(action)\n", + "\n", + " img.set_data(env.render(mode='rgb_array'))\n", + " display.display(plt.gcf())\n", + " display.clear_output(wait=True)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F5paWqo7tWL2" + }, + "source": [ + "## Policy Gradient\n", + "Now, we can build a simple policy network. The network will return a probability distribution over the actions in the action space." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "J8tdmeD-tZew" + }, + "source": [ + "class PolicyGradientNetwork(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(8, 16)\n", + " self.fc2 = nn.Linear(16, 16)\n", + " self.fc3 = nn.Linear(16, 4)\n", + "\n", + " def forward(self, state):\n", + " hid = torch.tanh(self.fc1(state))\n", + " hid = torch.tanh(self.fc2(hid))\n", + " return F.softmax(self.fc3(hid), dim=-1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ynbqJrhIFTC3" + }, + "source": [ + "Then, we need to build a simple agent. The agent will act according to the output of the policy network above. There are a few things that can be done by the agent:\n", + "- `learn()`: update the policy network from log probabilities and rewards.\n", + "- `sample()`: after receiving an observation from the environment, use the policy network to decide which action to take. The return values of this function include the action and its log probability. 
" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zZo-IxJx286z" + }, + "source": [ + "from torch.optim.lr_scheduler import StepLR\n", + "class PolicyGradientAgent():\n", + " \n", + " def __init__(self, network):\n", + " self.network = network\n", + " self.optimizer = optim.SGD(self.network.parameters(), lr=0.001)\n", + " \n", + " def forward(self, state):\n", + " return self.network(state)\n", + " def learn(self, log_probs, rewards):\n", + " loss = (-log_probs * rewards).sum() # You don't need to revise this to pass simple baseline (but you can)\n", + "\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " \n", + " def sample(self, state):\n", + " action_prob = self.network(torch.FloatTensor(state))\n", + " action_dist = Categorical(action_prob)\n", + " action = action_dist.sample()\n", + " log_prob = action_dist.log_prob(action)\n", + " return action.item(), log_prob" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehPlnTKyRZf9" + }, + "source": [ + "Lastly, build a network and agent to start training." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GfJIvML-RYjL" + }, + "source": [ + "network = PolicyGradientNetwork()\n", + "agent = PolicyGradientAgent(network)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ouv23glgf5Qt" + }, + "source": [ + "## Training Agent\n", + "\n", + "Now let's start to train our agent.\n", + "By taking all the interactions between the agent and the environment as training data, the policy network can learn from all these attempts." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "de3a153737af485ea436d7e8393d8248", + "6345f8926212465291c04587353161f1", + "3aa84c8c097d4858a0c38f18dec8b060", + "6647e68cf064416ca593b990bed81edf", + "a25de7edbbee47cc8094f43125efa39b", + "a0bb5061ca0946ab89a89f9abadcebc8", + "d99d8e6bcbe0445eb7eddbfe31277635", + "5af1c6579f0f4c2eb03df4063749569e" + ] + }, + "id": "vg5rxBBaf38_", + "outputId": "a1b06e39-99d6-4233-eda0-a3d58e77ffed" + }, + "source": [ + "agent.network.train() # Switch network into training mode \n", + "EPISODE_PER_BATCH = 5 # update the agent every 5 episode\n", + "NUM_BATCH = 400 # totally update the agent for 400 time\n", + "\n", + "avg_total_rewards, avg_final_rewards = [], []\n", + "\n", + "prg_bar = tqdm(range(NUM_BATCH))\n", + "for batch in prg_bar:\n", + "\n", + " log_probs, rewards = [], []\n", + " total_rewards, final_rewards = [], []\n", + "\n", + " # collect trajectory\n", + " for episode in range(EPISODE_PER_BATCH):\n", + " \n", + " state = env.reset()\n", + " total_reward, total_step = 0, 0\n", + " seq_rewards = []\n", + " while True:\n", + "\n", + " action, log_prob = agent.sample(state) # at, log(at|st)\n", + " next_state, reward, done, _ = env.step(action)\n", + "\n", + " log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]\n", + " # seq_rewards.append(reward)\n", + " state = 
next_state\n", + " total_reward += reward\n", + " total_step += 1\n", + " rewards.append(reward) # change here\n", + " # ! IMPORTANT !\n", + " # Current reward implementation: immediate reward, given action_list : a1, a2, a3 ......\n", + " # rewards : r1, r2 ,r3 ......\n", + " # medium:change \"rewards\" to accumulative decaying reward, given action_list : a1, a2, a3, ......\n", + " # rewards : r1+0.99*r2+0.99^2*r3+......, r2+0.99*r3+0.99^2*r4+...... , r3+0.99*r4+0.99^2*r5+ ......\n", + " # boss : implement DQN\n", + " if done:\n", + " final_rewards.append(reward)\n", + " total_rewards.append(total_reward)\n", + " \n", + " break\n", + "\n", + " print(f\"rewards looks like \", np.shape(rewards)) \n", + " print(f\"log_probs looks like \", np.shape(log_probs)) \n", + " # record training process\n", + " avg_total_reward = sum(total_rewards) / len(total_rewards)\n", + " avg_final_reward = sum(final_rewards) / len(final_rewards)\n", + " avg_total_rewards.append(avg_total_reward)\n", + " avg_final_rewards.append(avg_final_reward)\n", + " prg_bar.set_description(f\"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}\")\n", + "\n", + " # update agent\n", + " # rewards = np.concatenate(rewards, axis=0)\n", + " rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9) # normalize the reward \n", + " agent.learn(torch.stack(log_probs), torch.from_numpy(rewards))\n", + " print(\"logs prob looks like \", torch.stack(log_probs).size())\n", + " print(\"torch.from_numpy(rewards) looks like \", torch.from_numpy(rewards).size())" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "de3a153737af485ea436d7e8393d8248", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "rewards looks 
like (448,)\n", + "log_probs looks like (448,)\n", + "logs prob looks like torch.Size([448])\n", + "torch.from_numpy(rewards) looks like torch.Size([448])\n", + "rewards looks like (515,)\n", + "log_probs looks like (515,)\n", + "logs prob looks like torch.Size([515])\n", + "torch.from_numpy(rewards) looks like torch.Size([515])\n", + "rewards looks like (392,)\n", + "log_probs looks like (392,)\n", + "logs prob looks like torch.Size([392])\n", + "torch.from_numpy(rewards) looks like torch.Size([392])\n", + "rewards looks like (518,)\n", + "log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (472,)\n", + "log_probs looks like (472,)\n", + "logs prob looks like torch.Size([472])\n", + "torch.from_numpy(rewards) looks like torch.Size([472])\n", + "rewards looks like (530,)\n", + "log_probs looks like (530,)\n", + "logs prob looks like torch.Size([530])\n", + "torch.from_numpy(rewards) looks like torch.Size([530])\n", + "rewards looks like (463,)\n", + "log_probs looks like (463,)\n", + "logs prob looks like torch.Size([463])\n", + "torch.from_numpy(rewards) looks like torch.Size([463])\n", + "rewards looks like (540,)\n", + "log_probs looks like (540,)\n", + "logs prob looks like torch.Size([540])\n", + "torch.from_numpy(rewards) looks like torch.Size([540])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (449,)\n", + "log_probs looks like (449,)\n", + "logs prob looks like torch.Size([449])\n", + "torch.from_numpy(rewards) looks like torch.Size([449])\n", + "rewards looks like (602,)\n", + "log_probs looks like (602,)\n", + "logs prob looks like torch.Size([602])\n", + "torch.from_numpy(rewards) looks like torch.Size([602])\n", + "rewards looks like (542,)\n", + "log_probs looks like (542,)\n", + "logs prob 
looks like torch.Size([542])\n", + "torch.from_numpy(rewards) looks like torch.Size([542])\n", + "rewards looks like (503,)\n", + "log_probs looks like (503,)\n", + "logs prob looks like torch.Size([503])\n", + "torch.from_numpy(rewards) looks like torch.Size([503])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (518,)\n", + "log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (421,)\n", + "log_probs looks like (421,)\n", + "logs prob looks like torch.Size([421])\n", + "torch.from_numpy(rewards) looks like torch.Size([421])\n", + "rewards looks like (592,)\n", + "log_probs looks like (592,)\n", + "logs prob looks like torch.Size([592])\n", + "torch.from_numpy(rewards) looks like torch.Size([592])\n", + "rewards looks like (520,)\n", + "log_probs looks like (520,)\n", + "logs prob looks like torch.Size([520])\n", + "torch.from_numpy(rewards) looks like torch.Size([520])\n", + "rewards looks like (494,)\n", + "log_probs looks like (494,)\n", + "logs prob looks like torch.Size([494])\n", + "torch.from_numpy(rewards) looks like torch.Size([494])\n", + "rewards looks like (461,)\n", + "log_probs looks like (461,)\n", + "logs prob looks like torch.Size([461])\n", + "torch.from_numpy(rewards) looks like torch.Size([461])\n", + "rewards looks like (572,)\n", + "log_probs looks like (572,)\n", + "logs prob looks like torch.Size([572])\n", + "torch.from_numpy(rewards) looks like torch.Size([572])\n", + "rewards looks like (593,)\n", + "log_probs looks like (593,)\n", + "logs prob looks like torch.Size([593])\n", + "torch.from_numpy(rewards) looks like torch.Size([593])\n", + "rewards looks like (569,)\n", + "log_probs looks like (569,)\n", + "logs prob looks like torch.Size([569])\n", + "torch.from_numpy(rewards) 
looks like torch.Size([569])\n", + "rewards looks like (546,)\n", + "log_probs looks like (546,)\n", + "logs prob looks like torch.Size([546])\n", + "torch.from_numpy(rewards) looks like torch.Size([546])\n", + "rewards looks like (612,)\n", + "log_probs looks like (612,)\n", + "logs prob looks like torch.Size([612])\n", + "torch.from_numpy(rewards) looks like torch.Size([612])\n", + "rewards looks like (534,)\n", + "log_probs looks like (534,)\n", + "logs prob looks like torch.Size([534])\n", + "torch.from_numpy(rewards) looks like torch.Size([534])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (535,)\n", + "log_probs looks like (535,)\n", + "logs prob looks like torch.Size([535])\n", + "torch.from_numpy(rewards) looks like torch.Size([535])\n", + "rewards looks like (533,)\n", + "log_probs looks like (533,)\n", + "logs prob looks like torch.Size([533])\n", + "torch.from_numpy(rewards) looks like torch.Size([533])\n", + "rewards looks like (521,)\n", + "log_probs looks like (521,)\n", + "logs prob looks like torch.Size([521])\n", + "torch.from_numpy(rewards) looks like torch.Size([521])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (575,)\n", + "log_probs looks like (575,)\n", + "logs prob looks like torch.Size([575])\n", + "torch.from_numpy(rewards) looks like torch.Size([575])\n", + "rewards looks like 
(709,)\n", + "log_probs looks like (709,)\n", + "logs prob looks like torch.Size([709])\n", + "torch.from_numpy(rewards) looks like torch.Size([709])\n", + "rewards looks like (486,)\n", + "log_probs looks like (486,)\n", + "logs prob looks like torch.Size([486])\n", + "torch.from_numpy(rewards) looks like torch.Size([486])\n", + "rewards looks like (557,)\n", + "log_probs looks like (557,)\n", + "logs prob looks like torch.Size([557])\n", + "torch.from_numpy(rewards) looks like torch.Size([557])\n", + "rewards looks like (517,)\n", + "log_probs looks like (517,)\n", + "logs prob looks like torch.Size([517])\n", + "torch.from_numpy(rewards) looks like torch.Size([517])\n", + "rewards looks like (550,)\n", + "log_probs looks like (550,)\n", + "logs prob looks like torch.Size([550])\n", + "torch.from_numpy(rewards) looks like torch.Size([550])\n", + "rewards looks like (690,)\n", + "log_probs looks like (690,)\n", + "logs prob looks like torch.Size([690])\n", + "torch.from_numpy(rewards) looks like torch.Size([690])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (689,)\n", + "log_probs looks like (689,)\n", + "logs prob looks like torch.Size([689])\n", + "torch.from_numpy(rewards) looks like torch.Size([689])\n", + "rewards looks like (1059,)\n", + "log_probs looks like (1059,)\n", + "logs prob looks like torch.Size([1059])\n", + "torch.from_numpy(rewards) looks like torch.Size([1059])\n", + "rewards looks like (619,)\n", + "log_probs looks like (619,)\n", + "logs prob looks like torch.Size([619])\n", + "torch.from_numpy(rewards) looks like torch.Size([619])\n", + "rewards looks like (527,)\n", + "log_probs looks like (527,)\n", + "logs prob looks like torch.Size([527])\n", + "torch.from_numpy(rewards) looks like torch.Size([527])\n", + "rewards looks like (514,)\n", + "log_probs looks like (514,)\n", + "logs prob 
looks like torch.Size([514])\n", + "torch.from_numpy(rewards) looks like torch.Size([514])\n", + "rewards looks like (655,)\n", + "log_probs looks like (655,)\n", + "logs prob looks like torch.Size([655])\n", + "torch.from_numpy(rewards) looks like torch.Size([655])\n", + "rewards looks like (667,)\n", + "log_probs looks like (667,)\n", + "logs prob looks like torch.Size([667])\n", + "torch.from_numpy(rewards) looks like torch.Size([667])\n", + "rewards looks like (712,)\n", + "log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (636,)\n", + "log_probs looks like (636,)\n", + "logs prob looks like torch.Size([636])\n", + "torch.from_numpy(rewards) looks like torch.Size([636])\n", + "rewards looks like (620,)\n", + "log_probs looks like (620,)\n", + "logs prob looks like torch.Size([620])\n", + "torch.from_numpy(rewards) looks like torch.Size([620])\n", + "rewards looks like (543,)\n", + "log_probs looks like (543,)\n", + "logs prob looks like torch.Size([543])\n", + "torch.from_numpy(rewards) looks like torch.Size([543])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (498,)\n", + "log_probs looks like (498,)\n", + "logs prob looks like torch.Size([498])\n", + "torch.from_numpy(rewards) looks like torch.Size([498])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (693,)\n", + "log_probs looks like (693,)\n", + "logs prob looks like torch.Size([693])\n", + "torch.from_numpy(rewards) 
looks like torch.Size([693])\n", + "rewards looks like (648,)\n", + "log_probs looks like (648,)\n", + "logs prob looks like torch.Size([648])\n", + "torch.from_numpy(rewards) looks like torch.Size([648])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (574,)\n", + "log_probs looks like (574,)\n", + "logs prob looks like torch.Size([574])\n", + "torch.from_numpy(rewards) looks like torch.Size([574])\n", + "rewards looks like (718,)\n", + "log_probs looks like (718,)\n", + "logs prob looks like torch.Size([718])\n", + "torch.from_numpy(rewards) looks like torch.Size([718])\n", + "rewards looks like (730,)\n", + "log_probs looks like (730,)\n", + "logs prob looks like torch.Size([730])\n", + "torch.from_numpy(rewards) looks like torch.Size([730])\n", + "rewards looks like (668,)\n", + "log_probs looks like (668,)\n", + "logs prob looks like torch.Size([668])\n", + "torch.from_numpy(rewards) looks like torch.Size([668])\n", + "rewards looks like (754,)\n", + "log_probs looks like (754,)\n", + "logs prob looks like torch.Size([754])\n", + "torch.from_numpy(rewards) looks like torch.Size([754])\n", + "rewards looks like (712,)\n", + "log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (585,)\n", + "log_probs looks like (585,)\n", + "logs prob looks like torch.Size([585])\n", + "torch.from_numpy(rewards) looks like torch.Size([585])\n", + "rewards looks like 
(512,)\n", + "log_probs looks like (512,)\n", + "logs prob looks like torch.Size([512])\n", + "torch.from_numpy(rewards) looks like torch.Size([512])\n", + "rewards looks like (702,)\n", + "log_probs looks like (702,)\n", + "logs prob looks like torch.Size([702])\n", + "torch.from_numpy(rewards) looks like torch.Size([702])\n", + "rewards looks like (596,)\n", + "log_probs looks like (596,)\n", + "logs prob looks like torch.Size([596])\n", + "torch.from_numpy(rewards) looks like torch.Size([596])\n", + "rewards looks like (626,)\n", + "log_probs looks like (626,)\n", + "logs prob looks like torch.Size([626])\n", + "torch.from_numpy(rewards) looks like torch.Size([626])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (717,)\n", + "log_probs looks like (717,)\n", + "logs prob looks like torch.Size([717])\n", + "torch.from_numpy(rewards) looks like torch.Size([717])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (565,)\n", + "log_probs looks like (565,)\n", + "logs prob looks like torch.Size([565])\n", + "torch.from_numpy(rewards) looks like torch.Size([565])\n", + "rewards looks like (450,)\n", + "log_probs looks like (450,)\n", + "logs prob looks like torch.Size([450])\n", + "torch.from_numpy(rewards) looks like torch.Size([450])\n", + "rewards looks like (584,)\n", + "log_probs looks like (584,)\n", + "logs prob looks like torch.Size([584])\n", + "torch.from_numpy(rewards) looks like torch.Size([584])\n", + "rewards looks like (670,)\n", + "log_probs looks like (670,)\n", + "logs prob looks like torch.Size([670])\n", + "torch.from_numpy(rewards) looks like torch.Size([670])\n", + "rewards looks like (691,)\n", + "log_probs looks like (691,)\n", + "logs prob 
looks like torch.Size([691])\n", + "torch.from_numpy(rewards) looks like torch.Size([691])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (752,)\n", + "log_probs looks like (752,)\n", + "logs prob looks like torch.Size([752])\n", + "torch.from_numpy(rewards) looks like torch.Size([752])\n", + "rewards looks like (478,)\n", + "log_probs looks like (478,)\n", + "logs prob looks like torch.Size([478])\n", + "torch.from_numpy(rewards) looks like torch.Size([478])\n", + "rewards looks like (553,)\n", + "log_probs looks like (553,)\n", + "logs prob looks like torch.Size([553])\n", + "torch.from_numpy(rewards) looks like torch.Size([553])\n", + "rewards looks like (1660,)\n", + "log_probs looks like (1660,)\n", + "logs prob looks like torch.Size([1660])\n", + "torch.from_numpy(rewards) looks like torch.Size([1660])\n", + "rewards looks like (751,)\n", + "log_probs looks like (751,)\n", + "logs prob looks like torch.Size([751])\n", + "torch.from_numpy(rewards) looks like torch.Size([751])\n", + "rewards looks like (801,)\n", + "log_probs looks like (801,)\n", + "logs prob looks like torch.Size([801])\n", + "torch.from_numpy(rewards) looks like torch.Size([801])\n", + "rewards looks like (715,)\n", + "log_probs looks like (715,)\n", + "logs prob looks like torch.Size([715])\n", + "torch.from_numpy(rewards) looks like torch.Size([715])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (609,)\n", + "log_probs looks like (609,)\n", + "logs prob looks like torch.Size([609])\n", + "torch.from_numpy(rewards) looks like torch.Size([609])\n", + "rewards looks like (732,)\n", + "log_probs looks like (732,)\n", + "logs prob looks like torch.Size([732])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([732])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (658,)\n", + "log_probs looks like (658,)\n", + "logs prob looks like torch.Size([658])\n", + "torch.from_numpy(rewards) looks like torch.Size([658])\n", + "rewards looks like (783,)\n", + "log_probs looks like (783,)\n", + "logs prob looks like torch.Size([783])\n", + "torch.from_numpy(rewards) looks like torch.Size([783])\n", + "rewards looks like (652,)\n", + "log_probs looks like (652,)\n", + "logs prob looks like torch.Size([652])\n", + "torch.from_numpy(rewards) looks like torch.Size([652])\n", + "rewards looks like (892,)\n", + "log_probs looks like (892,)\n", + "logs prob looks like torch.Size([892])\n", + "torch.from_numpy(rewards) looks like torch.Size([892])\n", + "rewards looks like (821,)\n", + "log_probs looks like (821,)\n", + "logs prob looks like torch.Size([821])\n", + "torch.from_numpy(rewards) looks like torch.Size([821])\n", + "rewards looks like (986,)\n", + "log_probs looks like (986,)\n", + "logs prob looks like torch.Size([986])\n", + "torch.from_numpy(rewards) looks like torch.Size([986])\n", + "rewards looks like (916,)\n", + "log_probs looks like (916,)\n", + "logs prob looks like torch.Size([916])\n", + "torch.from_numpy(rewards) looks like torch.Size([916])\n", + "rewards looks like (742,)\n", + "log_probs looks like (742,)\n", + "logs prob looks like torch.Size([742])\n", + "torch.from_numpy(rewards) looks like torch.Size([742])\n", + 
"rewards looks like (604,)\n", + "log_probs looks like (604,)\n", + "logs prob looks like torch.Size([604])\n", + "torch.from_numpy(rewards) looks like torch.Size([604])\n", + "rewards looks like (818,)\n", + "log_probs looks like (818,)\n", + "logs prob looks like torch.Size([818])\n", + "torch.from_numpy(rewards) looks like torch.Size([818])\n", + "rewards looks like (855,)\n", + "log_probs looks like (855,)\n", + "logs prob looks like torch.Size([855])\n", + "torch.from_numpy(rewards) looks like torch.Size([855])\n", + "rewards looks like (795,)\n", + "log_probs looks like (795,)\n", + "logs prob looks like torch.Size([795])\n", + "torch.from_numpy(rewards) looks like torch.Size([795])\n", + "rewards looks like (868,)\n", + "log_probs looks like (868,)\n", + "logs prob looks like torch.Size([868])\n", + "torch.from_numpy(rewards) looks like torch.Size([868])\n", + "rewards looks like (800,)\n", + "log_probs looks like (800,)\n", + "logs prob looks like torch.Size([800])\n", + "torch.from_numpy(rewards) looks like torch.Size([800])\n", + "rewards looks like (820,)\n", + "log_probs looks like (820,)\n", + "logs prob looks like torch.Size([820])\n", + "torch.from_numpy(rewards) looks like torch.Size([820])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (886,)\n", + "log_probs looks like (886,)\n", + "logs prob looks like torch.Size([886])\n", + "torch.from_numpy(rewards) looks like torch.Size([886])\n", + "rewards looks like (1027,)\n", + "log_probs looks like (1027,)\n", + "logs prob looks like torch.Size([1027])\n", + "torch.from_numpy(rewards) looks like torch.Size([1027])\n", + "rewards looks like (819,)\n", + "log_probs looks like (819,)\n", + "logs prob looks like torch.Size([819])\n", + "torch.from_numpy(rewards) looks like torch.Size([819])\n", + "rewards looks like (934,)\n", + "log_probs looks like 
(934,)\n", + "logs prob looks like torch.Size([934])\n", + "torch.from_numpy(rewards) looks like torch.Size([934])\n", + "rewards looks like (1648,)\n", + "log_probs looks like (1648,)\n", + "logs prob looks like torch.Size([1648])\n", + "torch.from_numpy(rewards) looks like torch.Size([1648])\n", + "rewards looks like (1057,)\n", + "log_probs looks like (1057,)\n", + "logs prob looks like torch.Size([1057])\n", + "torch.from_numpy(rewards) looks like torch.Size([1057])\n", + "rewards looks like (861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1533,)\n", + "log_probs looks like (1533,)\n", + "logs prob looks like torch.Size([1533])\n", + "torch.from_numpy(rewards) looks like torch.Size([1533])\n", + "rewards looks like (920,)\n", + "log_probs looks like (920,)\n", + "logs prob looks like torch.Size([920])\n", + "torch.from_numpy(rewards) looks like torch.Size([920])\n", + "rewards looks like (905,)\n", + "log_probs looks like (905,)\n", + "logs prob looks like torch.Size([905])\n", + "torch.from_numpy(rewards) looks like torch.Size([905])\n", + "rewards looks like (814,)\n", + "log_probs looks like (814,)\n", + "logs prob looks like torch.Size([814])\n", + "torch.from_numpy(rewards) looks like torch.Size([814])\n", + "rewards looks like (809,)\n", + "log_probs looks like (809,)\n", + "logs prob looks like torch.Size([809])\n", + "torch.from_numpy(rewards) looks like torch.Size([809])\n", + "rewards looks like (873,)\n", + "log_probs looks like (873,)\n", + "logs prob looks like torch.Size([873])\n", + "torch.from_numpy(rewards) looks like torch.Size([873])\n", + "rewards looks like (727,)\n", + "log_probs looks like (727,)\n", + "logs prob looks like torch.Size([727])\n", + "torch.from_numpy(rewards) looks like torch.Size([727])\n", + "rewards looks like (1129,)\n", + "log_probs looks like (1129,)\n", + "logs prob looks like 
torch.Size([1129])\n", + "torch.from_numpy(rewards) looks like torch.Size([1129])\n", + "rewards looks like (1394,)\n", + "log_probs looks like (1394,)\n", + "logs prob looks like torch.Size([1394])\n", + "torch.from_numpy(rewards) looks like torch.Size([1394])\n", + "rewards looks like (884,)\n", + "log_probs looks like (884,)\n", + "logs prob looks like torch.Size([884])\n", + "torch.from_numpy(rewards) looks like torch.Size([884])\n", + "rewards looks like (1132,)\n", + "log_probs looks like (1132,)\n", + "logs prob looks like torch.Size([1132])\n", + "torch.from_numpy(rewards) looks like torch.Size([1132])\n", + "rewards looks like (1007,)\n", + "log_probs looks like (1007,)\n", + "logs prob looks like torch.Size([1007])\n", + "torch.from_numpy(rewards) looks like torch.Size([1007])\n", + "rewards looks like (711,)\n", + "log_probs looks like (711,)\n", + "logs prob looks like torch.Size([711])\n", + "torch.from_numpy(rewards) looks like torch.Size([711])\n", + "rewards looks like (836,)\n", + "log_probs looks like (836,)\n", + "logs prob looks like torch.Size([836])\n", + "torch.from_numpy(rewards) looks like torch.Size([836])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (896,)\n", + "log_probs looks like (896,)\n", + "logs prob looks like torch.Size([896])\n", + "torch.from_numpy(rewards) looks like torch.Size([896])\n", + "rewards looks like (912,)\n", + "log_probs looks like (912,)\n", + "logs prob looks like torch.Size([912])\n", + "torch.from_numpy(rewards) looks like torch.Size([912])\n", + "rewards looks like (1478,)\n", + "log_probs looks like (1478,)\n", + "logs prob looks like torch.Size([1478])\n", + "torch.from_numpy(rewards) looks like torch.Size([1478])\n", + "rewards looks like (1279,)\n", + "log_probs looks like (1279,)\n", + "logs prob looks like torch.Size([1279])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1279])\n", + "rewards looks like (676,)\n", + "log_probs looks like (676,)\n", + "logs prob looks like torch.Size([676])\n", + "torch.from_numpy(rewards) looks like torch.Size([676])\n", + "rewards looks like (1768,)\n", + "log_probs looks like (1768,)\n", + "logs prob looks like torch.Size([1768])\n", + "torch.from_numpy(rewards) looks like torch.Size([1768])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (1119,)\n", + "log_probs looks like (1119,)\n", + "logs prob looks like torch.Size([1119])\n", + "torch.from_numpy(rewards) looks like torch.Size([1119])\n", + "rewards looks like (943,)\n", + "log_probs looks like (943,)\n", + "logs prob looks like torch.Size([943])\n", + "torch.from_numpy(rewards) looks like torch.Size([943])\n", + "rewards looks like (1255,)\n", + "log_probs looks like (1255,)\n", + "logs prob looks like torch.Size([1255])\n", + "torch.from_numpy(rewards) looks like torch.Size([1255])\n", + "rewards looks like (861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1149,)\n", + "log_probs looks like (1149,)\n", + "logs prob looks like torch.Size([1149])\n", + "torch.from_numpy(rewards) looks like torch.Size([1149])\n", + "rewards looks like (1229,)\n", + "log_probs looks like (1229,)\n", + "logs prob looks like torch.Size([1229])\n", + "torch.from_numpy(rewards) looks like torch.Size([1229])\n", + "rewards looks like (1680,)\n", + "log_probs looks like (1680,)\n", + "logs prob looks like torch.Size([1680])\n", + "torch.from_numpy(rewards) looks like torch.Size([1680])\n", + "rewards looks like (1731,)\n", + "log_probs looks like (1731,)\n", + "logs prob looks like torch.Size([1731])\n", + "torch.from_numpy(rewards) looks 
like torch.Size([1731])\n", + "rewards looks like (1017,)\n", + "log_probs looks like (1017,)\n", + "logs prob looks like torch.Size([1017])\n", + "torch.from_numpy(rewards) looks like torch.Size([1017])\n", + "rewards looks like (990,)\n", + "log_probs looks like (990,)\n", + "logs prob looks like torch.Size([990])\n", + "torch.from_numpy(rewards) looks like torch.Size([990])\n", + "rewards looks like (1020,)\n", + "log_probs looks like (1020,)\n", + "logs prob looks like torch.Size([1020])\n", + "torch.from_numpy(rewards) looks like torch.Size([1020])\n", + "rewards looks like (1240,)\n", + "log_probs looks like (1240,)\n", + "logs prob looks like torch.Size([1240])\n", + "torch.from_numpy(rewards) looks like torch.Size([1240])\n", + "rewards looks like (774,)\n", + "log_probs looks like (774,)\n", + "logs prob looks like torch.Size([774])\n", + "torch.from_numpy(rewards) looks like torch.Size([774])\n", + "rewards looks like (1069,)\n", + "log_probs looks like (1069,)\n", + "logs prob looks like torch.Size([1069])\n", + "torch.from_numpy(rewards) looks like torch.Size([1069])\n", + "rewards looks like (1355,)\n", + "log_probs looks like (1355,)\n", + "logs prob looks like torch.Size([1355])\n", + "torch.from_numpy(rewards) looks like torch.Size([1355])\n", + "rewards looks like (1556,)\n", + "log_probs looks like (1556,)\n", + "logs prob looks like torch.Size([1556])\n", + "torch.from_numpy(rewards) looks like torch.Size([1556])\n", + "rewards looks like (1840,)\n", + "log_probs looks like (1840,)\n", + "logs prob looks like torch.Size([1840])\n", + "torch.from_numpy(rewards) looks like torch.Size([1840])\n", + "rewards looks like (1352,)\n", + "log_probs looks like (1352,)\n", + "logs prob looks like torch.Size([1352])\n", + "torch.from_numpy(rewards) looks like torch.Size([1352])\n", + "rewards looks like (1617,)\n", + "log_probs looks like (1617,)\n", + "logs prob looks like torch.Size([1617])\n", + "torch.from_numpy(rewards) looks like torch.Size([1617])\n", 
+ "rewards looks like (1637,)\n", + "log_probs looks like (1637,)\n", + "logs prob looks like torch.Size([1637])\n", + "torch.from_numpy(rewards) looks like torch.Size([1637])\n", + "rewards looks like (1606,)\n", + "log_probs looks like (1606,)\n", + "logs prob looks like torch.Size([1606])\n", + "torch.from_numpy(rewards) looks like torch.Size([1606])\n", + "rewards looks like (860,)\n", + "log_probs looks like (860,)\n", + "logs prob looks like torch.Size([860])\n", + "torch.from_numpy(rewards) looks like torch.Size([860])\n", + "rewards looks like (1780,)\n", + "log_probs looks like (1780,)\n", + "logs prob looks like torch.Size([1780])\n", + "torch.from_numpy(rewards) looks like torch.Size([1780])\n", + "rewards looks like (2248,)\n", + "log_probs looks like (2248,)\n", + "logs prob looks like torch.Size([2248])\n", + "torch.from_numpy(rewards) looks like torch.Size([2248])\n", + "rewards looks like (1410,)\n", + "log_probs looks like (1410,)\n", + "logs prob looks like torch.Size([1410])\n", + "torch.from_numpy(rewards) looks like torch.Size([1410])\n", + "rewards looks like (557,)\n", + "log_probs looks like (557,)\n", + "logs prob looks like torch.Size([557])\n", + "torch.from_numpy(rewards) looks like torch.Size([557])\n", + "rewards looks like (719,)\n", + "log_probs looks like (719,)\n", + "logs prob looks like torch.Size([719])\n", + "torch.from_numpy(rewards) looks like torch.Size([719])\n", + "rewards looks like (1919,)\n", + "log_probs looks like (1919,)\n", + "logs prob looks like torch.Size([1919])\n", + "torch.from_numpy(rewards) looks like torch.Size([1919])\n", + "rewards looks like (1250,)\n", + "log_probs looks like (1250,)\n", + "logs prob looks like torch.Size([1250])\n", + "torch.from_numpy(rewards) looks like torch.Size([1250])\n", + "rewards looks like (1054,)\n", + "log_probs looks like (1054,)\n", + "logs prob looks like torch.Size([1054])\n", + "torch.from_numpy(rewards) looks like torch.Size([1054])\n", + "rewards looks like 
(1276,)\n", + "log_probs looks like (1276,)\n", + "logs prob looks like torch.Size([1276])\n", + "torch.from_numpy(rewards) looks like torch.Size([1276])\n", + "rewards looks like (1040,)\n", + "log_probs looks like (1040,)\n", + "logs prob looks like torch.Size([1040])\n", + "torch.from_numpy(rewards) looks like torch.Size([1040])\n", + "rewards looks like (991,)\n", + "log_probs looks like (991,)\n", + "logs prob looks like torch.Size([991])\n", + "torch.from_numpy(rewards) looks like torch.Size([991])\n", + "rewards looks like (1390,)\n", + "log_probs looks like (1390,)\n", + "logs prob looks like torch.Size([1390])\n", + "torch.from_numpy(rewards) looks like torch.Size([1390])\n", + "rewards looks like (1349,)\n", + "log_probs looks like (1349,)\n", + "logs prob looks like torch.Size([1349])\n", + "torch.from_numpy(rewards) looks like torch.Size([1349])\n", + "rewards looks like (1332,)\n", + "log_probs looks like (1332,)\n", + "logs prob looks like torch.Size([1332])\n", + "torch.from_numpy(rewards) looks like torch.Size([1332])\n", + "rewards looks like (1378,)\n", + "log_probs looks like (1378,)\n", + "logs prob looks like torch.Size([1378])\n", + "torch.from_numpy(rewards) looks like torch.Size([1378])\n", + "rewards looks like (1967,)\n", + "log_probs looks like (1967,)\n", + "logs prob looks like torch.Size([1967])\n", + "torch.from_numpy(rewards) looks like torch.Size([1967])\n", + "rewards looks like (1789,)\n", + "log_probs looks like (1789,)\n", + "logs prob looks like torch.Size([1789])\n", + "torch.from_numpy(rewards) looks like torch.Size([1789])\n", + "rewards looks like (1325,)\n", + "log_probs looks like (1325,)\n", + "logs prob looks like torch.Size([1325])\n", + "torch.from_numpy(rewards) looks like torch.Size([1325])\n", + "rewards looks like (1685,)\n", + "log_probs looks like (1685,)\n", + "logs prob looks like torch.Size([1685])\n", + "torch.from_numpy(rewards) looks like torch.Size([1685])\n", + "rewards looks like (1895,)\n", + 
"log_probs looks like (1895,)\n", + "logs prob looks like torch.Size([1895])\n", + "torch.from_numpy(rewards) looks like torch.Size([1895])\n", + "rewards looks like (1920,)\n", + "log_probs looks like (1920,)\n", + "logs prob looks like torch.Size([1920])\n", + "torch.from_numpy(rewards) looks like torch.Size([1920])\n", + "rewards looks like (1522,)\n", + "log_probs looks like (1522,)\n", + "logs prob looks like torch.Size([1522])\n", + "torch.from_numpy(rewards) looks like torch.Size([1522])\n", + "rewards looks like (1173,)\n", + "log_probs looks like (1173,)\n", + "logs prob looks like torch.Size([1173])\n", + "torch.from_numpy(rewards) looks like torch.Size([1173])\n", + "rewards looks like (2136,)\n", + "log_probs looks like (2136,)\n", + "logs prob looks like torch.Size([2136])\n", + "torch.from_numpy(rewards) looks like torch.Size([2136])\n", + "rewards looks like (1696,)\n", + "log_probs looks like (1696,)\n", + "logs prob looks like torch.Size([1696])\n", + "torch.from_numpy(rewards) looks like torch.Size([1696])\n", + "rewards looks like (568,)\n", + "log_probs looks like (568,)\n", + "logs prob looks like torch.Size([568])\n", + "torch.from_numpy(rewards) looks like torch.Size([568])\n", + "rewards looks like (1475,)\n", + "log_probs looks like (1475,)\n", + "logs prob looks like torch.Size([1475])\n", + "torch.from_numpy(rewards) looks like torch.Size([1475])\n", + "rewards looks like (2470,)\n", + "log_probs looks like (2470,)\n", + "logs prob looks like torch.Size([2470])\n", + "torch.from_numpy(rewards) looks like torch.Size([2470])\n", + "rewards looks like (3053,)\n", + "log_probs looks like (3053,)\n", + "logs prob looks like torch.Size([3053])\n", + "torch.from_numpy(rewards) looks like torch.Size([3053])\n", + "rewards looks like (915,)\n", + "log_probs looks like (915,)\n", + "logs prob looks like torch.Size([915])\n", + "torch.from_numpy(rewards) looks like torch.Size([915])\n", + "rewards looks like (2049,)\n", + "log_probs looks like 
(2049,)\n", + "logs prob looks like torch.Size([2049])\n", + "torch.from_numpy(rewards) looks like torch.Size([2049])\n", + "rewards looks like (2068,)\n", + "log_probs looks like (2068,)\n", + "logs prob looks like torch.Size([2068])\n", + "torch.from_numpy(rewards) looks like torch.Size([2068])\n", + "rewards looks like (2528,)\n", + "log_probs looks like (2528,)\n", + "logs prob looks like torch.Size([2528])\n", + "torch.from_numpy(rewards) looks like torch.Size([2528])\n", + "rewards looks like (1839,)\n", + "log_probs looks like (1839,)\n", + "logs prob looks like torch.Size([1839])\n", + "torch.from_numpy(rewards) looks like torch.Size([1839])\n", + "rewards looks like (497,)\n", + "log_probs looks like (497,)\n", + "logs prob looks like torch.Size([497])\n", + "torch.from_numpy(rewards) looks like torch.Size([497])\n", + "rewards looks like (627,)\n", + "log_probs looks like (627,)\n", + "logs prob looks like torch.Size([627])\n", + "torch.from_numpy(rewards) looks like torch.Size([627])\n", + "rewards looks like (2354,)\n", + "log_probs looks like (2354,)\n", + "logs prob looks like torch.Size([2354])\n", + "torch.from_numpy(rewards) looks like torch.Size([2354])\n", + "rewards looks like (2394,)\n", + "log_probs looks like (2394,)\n", + "logs prob looks like torch.Size([2394])\n", + "torch.from_numpy(rewards) looks like torch.Size([2394])\n", + "rewards looks like (743,)\n", + "log_probs looks like (743,)\n", + "logs prob looks like torch.Size([743])\n", + "torch.from_numpy(rewards) looks like torch.Size([743])\n", + "rewards looks like (1572,)\n", + "log_probs looks like (1572,)\n", + "logs prob looks like torch.Size([1572])\n", + "torch.from_numpy(rewards) looks like torch.Size([1572])\n", + "rewards looks like (2575,)\n", + "log_probs looks like (2575,)\n", + "logs prob looks like torch.Size([2575])\n", + "torch.from_numpy(rewards) looks like torch.Size([2575])\n", + "rewards looks like (2226,)\n", + "log_probs looks like (2226,)\n", + "logs prob looks 
like torch.Size([2226])\n", + "torch.from_numpy(rewards) looks like torch.Size([2226])\n", + "rewards looks like (541,)\n", + "log_probs looks like (541,)\n", + "logs prob looks like torch.Size([541])\n", + "torch.from_numpy(rewards) looks like torch.Size([541])\n", + "rewards looks like (820,)\n", + "log_probs looks like (820,)\n", + "logs prob looks like torch.Size([820])\n", + "torch.from_numpy(rewards) looks like torch.Size([820])\n", + "rewards looks like (2584,)\n", + "log_probs looks like (2584,)\n", + "logs prob looks like torch.Size([2584])\n", + "torch.from_numpy(rewards) looks like torch.Size([2584])\n", + "rewards looks like (1792,)\n", + "log_probs looks like (1792,)\n", + "logs prob looks like torch.Size([1792])\n", + "torch.from_numpy(rewards) looks like torch.Size([1792])\n", + "rewards looks like (1613,)\n", + "log_probs looks like (1613,)\n", + "logs prob looks like torch.Size([1613])\n", + "torch.from_numpy(rewards) looks like torch.Size([1613])\n", + "rewards looks like (4300,)\n", + "log_probs looks like (4300,)\n", + "logs prob looks like torch.Size([4300])\n", + "torch.from_numpy(rewards) looks like torch.Size([4300])\n", + "rewards looks like (1602,)\n", + "log_probs looks like (1602,)\n", + "logs prob looks like torch.Size([1602])\n", + "torch.from_numpy(rewards) looks like torch.Size([1602])\n", + "rewards looks like (3313,)\n", + "log_probs looks like (3313,)\n", + "logs prob looks like torch.Size([3313])\n", + "torch.from_numpy(rewards) looks like torch.Size([3313])\n", + "rewards looks like (1538,)\n", + "log_probs looks like (1538,)\n", + "logs prob looks like torch.Size([1538])\n", + "torch.from_numpy(rewards) looks like torch.Size([1538])\n", + "rewards looks like (1824,)\n", + "log_probs looks like (1824,)\n", + "logs prob looks like torch.Size([1824])\n", + "torch.from_numpy(rewards) looks like torch.Size([1824])\n", + "rewards looks like (1320,)\n", + "log_probs looks like (1320,)\n", + "logs prob looks like torch.Size([1320])\n", 
+ "torch.from_numpy(rewards) looks like torch.Size([1320])\n", + "rewards looks like (2077,)\n", + "log_probs looks like (2077,)\n", + "logs prob looks like torch.Size([2077])\n", + "torch.from_numpy(rewards) looks like torch.Size([2077])\n", + "rewards looks like (1995,)\n", + "log_probs looks like (1995,)\n", + "logs prob looks like torch.Size([1995])\n", + "torch.from_numpy(rewards) looks like torch.Size([1995])\n", + "rewards looks like (1089,)\n", + "log_probs looks like (1089,)\n", + "logs prob looks like torch.Size([1089])\n", + "torch.from_numpy(rewards) looks like torch.Size([1089])\n", + "rewards looks like (1135,)\n", + "log_probs looks like (1135,)\n", + "logs prob looks like torch.Size([1135])\n", + "torch.from_numpy(rewards) looks like torch.Size([1135])\n", + "rewards looks like (1617,)\n", + "log_probs looks like (1617,)\n", + "logs prob looks like torch.Size([1617])\n", + "torch.from_numpy(rewards) looks like torch.Size([1617])\n", + "rewards looks like (942,)\n", + "log_probs looks like (942,)\n", + "logs prob looks like torch.Size([942])\n", + "torch.from_numpy(rewards) looks like torch.Size([942])\n", + "rewards looks like (2006,)\n", + "log_probs looks like (2006,)\n", + "logs prob looks like torch.Size([2006])\n", + "torch.from_numpy(rewards) looks like torch.Size([2006])\n", + "rewards looks like (2204,)\n", + "log_probs looks like (2204,)\n", + "logs prob looks like torch.Size([2204])\n", + "torch.from_numpy(rewards) looks like torch.Size([2204])\n", + "rewards looks like (1060,)\n", + "log_probs looks like (1060,)\n", + "logs prob looks like torch.Size([1060])\n", + "torch.from_numpy(rewards) looks like torch.Size([1060])\n", + "rewards looks like (1994,)\n", + "log_probs looks like (1994,)\n", + "logs prob looks like torch.Size([1994])\n", + "torch.from_numpy(rewards) looks like torch.Size([1994])\n", + "rewards looks like (1118,)\n", + "log_probs looks like (1118,)\n", + "logs prob looks like torch.Size([1118])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1118])\n", + "rewards looks like (1298,)\n", + "log_probs looks like (1298,)\n", + "logs prob looks like torch.Size([1298])\n", + "torch.from_numpy(rewards) looks like torch.Size([1298])\n", + "rewards looks like (1377,)\n", + "log_probs looks like (1377,)\n", + "logs prob looks like torch.Size([1377])\n", + "torch.from_numpy(rewards) looks like torch.Size([1377])\n", + "rewards looks like (1902,)\n", + "log_probs looks like (1902,)\n", + "logs prob looks like torch.Size([1902])\n", + "torch.from_numpy(rewards) looks like torch.Size([1902])\n", + "rewards looks like (1982,)\n", + "log_probs looks like (1982,)\n", + "logs prob looks like torch.Size([1982])\n", + "torch.from_numpy(rewards) looks like torch.Size([1982])\n", + "rewards looks like (1625,)\n", + "log_probs looks like (1625,)\n", + "logs prob looks like torch.Size([1625])\n", + "torch.from_numpy(rewards) looks like torch.Size([1625])\n", + "rewards looks like (1947,)\n", + "log_probs looks like (1947,)\n", + "logs prob looks like torch.Size([1947])\n", + "torch.from_numpy(rewards) looks like torch.Size([1947])\n", + "rewards looks like (1589,)\n", + "log_probs looks like (1589,)\n", + "logs prob looks like torch.Size([1589])\n", + "torch.from_numpy(rewards) looks like torch.Size([1589])\n", + "rewards looks like (1625,)\n", + "log_probs looks like (1625,)\n", + "logs prob looks like torch.Size([1625])\n", + "torch.from_numpy(rewards) looks like torch.Size([1625])\n", + "rewards looks like (1492,)\n", + "log_probs looks like (1492,)\n", + "logs prob looks like torch.Size([1492])\n", + "torch.from_numpy(rewards) looks like torch.Size([1492])\n", + "rewards looks like (1347,)\n", + "log_probs looks like (1347,)\n", + "logs prob looks like torch.Size([1347])\n", + "torch.from_numpy(rewards) looks like torch.Size([1347])\n", + "rewards looks like (2110,)\n", + "log_probs looks like (2110,)\n", + "logs prob looks like torch.Size([2110])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2110])\n", + "rewards looks like (877,)\n", + "log_probs looks like (877,)\n", + "logs prob looks like torch.Size([877])\n", + "torch.from_numpy(rewards) looks like torch.Size([877])\n", + "rewards looks like (1078,)\n", + "log_probs looks like (1078,)\n", + "logs prob looks like torch.Size([1078])\n", + "torch.from_numpy(rewards) looks like torch.Size([1078])\n", + "rewards looks like (2001,)\n", + "log_probs looks like (2001,)\n", + "logs prob looks like torch.Size([2001])\n", + "torch.from_numpy(rewards) looks like torch.Size([2001])\n", + "rewards looks like (1452,)\n", + "log_probs looks like (1452,)\n", + "logs prob looks like torch.Size([1452])\n", + "torch.from_numpy(rewards) looks like torch.Size([1452])\n", + "rewards looks like (1169,)\n", + "log_probs looks like (1169,)\n", + "logs prob looks like torch.Size([1169])\n", + "torch.from_numpy(rewards) looks like torch.Size([1169])\n", + "rewards looks like (1977,)\n", + "log_probs looks like (1977,)\n", + "logs prob looks like torch.Size([1977])\n", + "torch.from_numpy(rewards) looks like torch.Size([1977])\n", + "rewards looks like (1263,)\n", + "log_probs looks like (1263,)\n", + "logs prob looks like torch.Size([1263])\n", + "torch.from_numpy(rewards) looks like torch.Size([1263])\n", + "rewards looks like (2219,)\n", + "log_probs looks like (2219,)\n", + "logs prob looks like torch.Size([2219])\n", + "torch.from_numpy(rewards) looks like torch.Size([2219])\n", + "rewards looks like (1732,)\n", + "log_probs looks like (1732,)\n", + "logs prob looks like torch.Size([1732])\n", + "torch.from_numpy(rewards) looks like torch.Size([1732])\n", + "rewards looks like (1413,)\n", + "log_probs looks like (1413,)\n", + "logs prob looks like torch.Size([1413])\n", + "torch.from_numpy(rewards) looks like torch.Size([1413])\n", + "rewards looks like (1099,)\n", + "log_probs looks like (1099,)\n", + "logs prob looks like torch.Size([1099])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1099])\n", + "rewards looks like (1184,)\n", + "log_probs looks like (1184,)\n", + "logs prob looks like torch.Size([1184])\n", + "torch.from_numpy(rewards) looks like torch.Size([1184])\n", + "rewards looks like (1148,)\n", + "log_probs looks like (1148,)\n", + "logs prob looks like torch.Size([1148])\n", + "torch.from_numpy(rewards) looks like torch.Size([1148])\n", + "rewards looks like (1339,)\n", + "log_probs looks like (1339,)\n", + "logs prob looks like torch.Size([1339])\n", + "torch.from_numpy(rewards) looks like torch.Size([1339])\n", + "rewards looks like (2095,)\n", + "log_probs looks like (2095,)\n", + "logs prob looks like torch.Size([2095])\n", + "torch.from_numpy(rewards) looks like torch.Size([2095])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (1276,)\n", + "log_probs looks like (1276,)\n", + "logs prob looks like torch.Size([1276])\n", + "torch.from_numpy(rewards) looks like torch.Size([1276])\n", + "rewards looks like (1277,)\n", + "log_probs looks like (1277,)\n", + "logs prob looks like torch.Size([1277])\n", + "torch.from_numpy(rewards) looks like torch.Size([1277])\n", + "rewards looks like (1453,)\n", + "log_probs looks like (1453,)\n", + "logs prob looks like torch.Size([1453])\n", + "torch.from_numpy(rewards) looks like torch.Size([1453])\n", + "rewards looks like (1467,)\n", + "log_probs looks like (1467,)\n", + "logs prob looks like torch.Size([1467])\n", + "torch.from_numpy(rewards) looks like torch.Size([1467])\n", + "rewards looks like (1383,)\n", + "log_probs looks like (1383,)\n", + "logs prob looks like torch.Size([1383])\n", + "torch.from_numpy(rewards) looks like torch.Size([1383])\n", + "rewards looks like (1741,)\n", + "log_probs looks like (1741,)\n", + "logs prob looks like torch.Size([1741])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1741])\n", + "rewards looks like (1039,)\n", + "log_probs looks like (1039,)\n", + "logs prob looks like torch.Size([1039])\n", + "torch.from_numpy(rewards) looks like torch.Size([1039])\n", + "rewards looks like (1063,)\n", + "log_probs looks like (1063,)\n", + "logs prob looks like torch.Size([1063])\n", + "torch.from_numpy(rewards) looks like torch.Size([1063])\n", + "rewards looks like (1731,)\n", + "log_probs looks like (1731,)\n", + "logs prob looks like torch.Size([1731])\n", + "torch.from_numpy(rewards) looks like torch.Size([1731])\n", + "rewards looks like (2661,)\n", + "log_probs looks like (2661,)\n", + "logs prob looks like torch.Size([2661])\n", + "torch.from_numpy(rewards) looks like torch.Size([2661])\n", + "rewards looks like (704,)\n", + "log_probs looks like (704,)\n", + "logs prob looks like torch.Size([704])\n", + "torch.from_numpy(rewards) looks like torch.Size([704])\n", + "rewards looks like (1389,)\n", + "log_probs looks like (1389,)\n", + "logs prob looks like torch.Size([1389])\n", + "torch.from_numpy(rewards) looks like torch.Size([1389])\n", + "rewards looks like (2131,)\n", + "log_probs looks like (2131,)\n", + "logs prob looks like torch.Size([2131])\n", + "torch.from_numpy(rewards) looks like torch.Size([2131])\n", + "rewards looks like (1779,)\n", + "log_probs looks like (1779,)\n", + "logs prob looks like torch.Size([1779])\n", + "torch.from_numpy(rewards) looks like torch.Size([1779])\n", + "rewards looks like (1415,)\n", + "log_probs looks like (1415,)\n", + "logs prob looks like torch.Size([1415])\n", + "torch.from_numpy(rewards) looks like torch.Size([1415])\n", + "rewards looks like (2320,)\n", + "log_probs looks like (2320,)\n", + "logs prob looks like torch.Size([2320])\n", + "torch.from_numpy(rewards) looks like torch.Size([2320])\n", + "rewards looks like (1147,)\n", + "log_probs looks like (1147,)\n", + "logs prob looks like torch.Size([1147])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1147])\n", + "rewards looks like (1022,)\n", + "log_probs looks like (1022,)\n", + "logs prob looks like torch.Size([1022])\n", + "torch.from_numpy(rewards) looks like torch.Size([1022])\n", + "rewards looks like (2141,)\n", + "log_probs looks like (2141,)\n", + "logs prob looks like torch.Size([2141])\n", + "torch.from_numpy(rewards) looks like torch.Size([2141])\n", + "rewards looks like (1362,)\n", + "log_probs looks like (1362,)\n", + "logs prob looks like torch.Size([1362])\n", + "torch.from_numpy(rewards) looks like torch.Size([1362])\n", + "rewards looks like (1450,)\n", + "log_probs looks like (1450,)\n", + "logs prob looks like torch.Size([1450])\n", + "torch.from_numpy(rewards) looks like torch.Size([1450])\n", + "rewards looks like (1546,)\n", + "log_probs looks like (1546,)\n", + "logs prob looks like torch.Size([1546])\n", + "torch.from_numpy(rewards) looks like torch.Size([1546])\n", + "rewards looks like (1166,)\n", + "log_probs looks like (1166,)\n", + "logs prob looks like torch.Size([1166])\n", + "torch.from_numpy(rewards) looks like torch.Size([1166])\n", + "rewards looks like (1647,)\n", + "log_probs looks like (1647,)\n", + "logs prob looks like torch.Size([1647])\n", + "torch.from_numpy(rewards) looks like torch.Size([1647])\n", + "rewards looks like (1205,)\n", + "log_probs looks like (1205,)\n", + "logs prob looks like torch.Size([1205])\n", + "torch.from_numpy(rewards) looks like torch.Size([1205])\n", + "rewards looks like (2098,)\n", + "log_probs looks like (2098,)\n", + "logs prob looks like torch.Size([2098])\n", + "torch.from_numpy(rewards) looks like torch.Size([2098])\n", + "rewards looks like (1940,)\n", + "log_probs looks like (1940,)\n", + "logs prob looks like torch.Size([1940])\n", + "torch.from_numpy(rewards) looks like torch.Size([1940])\n", + "rewards looks like (2191,)\n", + "log_probs looks like (2191,)\n", + "logs prob looks like torch.Size([2191])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2191])\n", + "rewards looks like (2740,)\n", + "log_probs looks like (2740,)\n", + "logs prob looks like torch.Size([2740])\n", + "torch.from_numpy(rewards) looks like torch.Size([2740])\n", + "rewards looks like (587,)\n", + "log_probs looks like (587,)\n", + "logs prob looks like torch.Size([587])\n", + "torch.from_numpy(rewards) looks like torch.Size([587])\n", + "rewards looks like (1063,)\n", + "log_probs looks like (1063,)\n", + "logs prob looks like torch.Size([1063])\n", + "torch.from_numpy(rewards) looks like torch.Size([1063])\n", + "rewards looks like (861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1051,)\n", + "log_probs looks like (1051,)\n", + "logs prob looks like torch.Size([1051])\n", + "torch.from_numpy(rewards) looks like torch.Size([1051])\n", + "rewards looks like (1389,)\n", + "log_probs looks like (1389,)\n", + "logs prob looks like torch.Size([1389])\n", + "torch.from_numpy(rewards) looks like torch.Size([1389])\n", + "rewards looks like (1152,)\n", + "log_probs looks like (1152,)\n", + "logs prob looks like torch.Size([1152])\n", + "torch.from_numpy(rewards) looks like torch.Size([1152])\n", + "rewards looks like (1103,)\n", + "log_probs looks like (1103,)\n", + "logs prob looks like torch.Size([1103])\n", + "torch.from_numpy(rewards) looks like torch.Size([1103])\n", + "rewards looks like (1887,)\n", + "log_probs looks like (1887,)\n", + "logs prob looks like torch.Size([1887])\n", + "torch.from_numpy(rewards) looks like torch.Size([1887])\n", + "rewards looks like (1753,)\n", + "log_probs looks like (1753,)\n", + "logs prob looks like torch.Size([1753])\n", + "torch.from_numpy(rewards) looks like torch.Size([1753])\n", + "rewards looks like (1372,)\n", + "log_probs looks like (1372,)\n", + "logs prob looks like torch.Size([1372])\n", + "torch.from_numpy(rewards) 
looks like torch.Size([1372])\n", + "rewards looks like (1056,)\n", + "log_probs looks like (1056,)\n", + "logs prob looks like torch.Size([1056])\n", + "torch.from_numpy(rewards) looks like torch.Size([1056])\n", + "rewards looks like (1465,)\n", + "log_probs looks like (1465,)\n", + "logs prob looks like torch.Size([1465])\n", + "torch.from_numpy(rewards) looks like torch.Size([1465])\n", + "rewards looks like (3297,)\n", + "log_probs looks like (3297,)\n", + "logs prob looks like torch.Size([3297])\n", + "torch.from_numpy(rewards) looks like torch.Size([3297])\n", + "rewards looks like (2492,)\n", + "log_probs looks like (2492,)\n", + "logs prob looks like torch.Size([2492])\n", + "torch.from_numpy(rewards) looks like torch.Size([2492])\n", + "rewards looks like (1580,)\n", + "log_probs looks like (1580,)\n", + "logs prob looks like torch.Size([1580])\n", + "torch.from_numpy(rewards) looks like torch.Size([1580])\n", + "rewards looks like (1357,)\n", + "log_probs looks like (1357,)\n", + "logs prob looks like torch.Size([1357])\n", + "torch.from_numpy(rewards) looks like torch.Size([1357])\n", + "rewards looks like (1227,)\n", + "log_probs looks like (1227,)\n", + "logs prob looks like torch.Size([1227])\n", + "torch.from_numpy(rewards) looks like torch.Size([1227])\n", + "rewards looks like (2123,)\n", + "log_probs looks like (2123,)\n", + "logs prob looks like torch.Size([2123])\n", + "torch.from_numpy(rewards) looks like torch.Size([2123])\n", + "rewards looks like (1864,)\n", + "log_probs looks like (1864,)\n", + "logs prob looks like torch.Size([1864])\n", + "torch.from_numpy(rewards) looks like torch.Size([1864])\n", + "rewards looks like (1324,)\n", + "log_probs looks like (1324,)\n", + "logs prob looks like torch.Size([1324])\n", + "torch.from_numpy(rewards) looks like torch.Size([1324])\n", + "rewards looks like (1281,)\n", + "log_probs looks like (1281,)\n", + "logs prob looks like torch.Size([1281])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([1281])\n", + "rewards looks like (1366,)\n", + "log_probs looks like (1366,)\n", + "logs prob looks like torch.Size([1366])\n", + "torch.from_numpy(rewards) looks like torch.Size([1366])\n", + "rewards looks like (957,)\n", + "log_probs looks like (957,)\n", + "logs prob looks like torch.Size([957])\n", + "torch.from_numpy(rewards) looks like torch.Size([957])\n", + "rewards looks like (1187,)\n", + "log_probs looks like (1187,)\n", + "logs prob looks like torch.Size([1187])\n", + "torch.from_numpy(rewards) looks like torch.Size([1187])\n", + "rewards looks like (1625,)\n", + "log_probs looks like (1625,)\n", + "logs prob looks like torch.Size([1625])\n", + "torch.from_numpy(rewards) looks like torch.Size([1625])\n", + "rewards looks like (1605,)\n", + "log_probs looks like (1605,)\n", + "logs prob looks like torch.Size([1605])\n", + "torch.from_numpy(rewards) looks like torch.Size([1605])\n", + "rewards looks like (1015,)\n", + "log_probs looks like (1015,)\n", + "logs prob looks like torch.Size([1015])\n", + "torch.from_numpy(rewards) looks like torch.Size([1015])\n", + "rewards looks like (1565,)\n", + "log_probs looks like (1565,)\n", + "logs prob looks like torch.Size([1565])\n", + "torch.from_numpy(rewards) looks like torch.Size([1565])\n", + "rewards looks like (1353,)\n", + "log_probs looks like (1353,)\n", + "logs prob looks like torch.Size([1353])\n", + "torch.from_numpy(rewards) looks like torch.Size([1353])\n", + "rewards looks like (1321,)\n", + "log_probs looks like (1321,)\n", + "logs prob looks like torch.Size([1321])\n", + "torch.from_numpy(rewards) looks like torch.Size([1321])\n", + "rewards looks like (1074,)\n", + "log_probs looks like (1074,)\n", + "logs prob looks like torch.Size([1074])\n", + "torch.from_numpy(rewards) looks like torch.Size([1074])\n", + "rewards looks like (1301,)\n", + "log_probs looks like (1301,)\n", + "logs prob looks like torch.Size([1301])\n", + "torch.from_numpy(rewards) looks like torch.Size([1301])\n", 
+ "rewards looks like (2105,)\n", + "log_probs looks like (2105,)\n", + "logs prob looks like torch.Size([2105])\n", + "torch.from_numpy(rewards) looks like torch.Size([2105])\n", + "rewards looks like (2008,)\n", + "log_probs looks like (2008,)\n", + "logs prob looks like torch.Size([2008])\n", + "torch.from_numpy(rewards) looks like torch.Size([2008])\n", + "rewards looks like (1885,)\n", + "log_probs looks like (1885,)\n", + "logs prob looks like torch.Size([1885])\n", + "torch.from_numpy(rewards) looks like torch.Size([1885])\n", + "rewards looks like (1184,)\n", + "log_probs looks like (1184,)\n", + "logs prob looks like torch.Size([1184])\n", + "torch.from_numpy(rewards) looks like torch.Size([1184])\n", + "rewards looks like (2551,)\n", + "log_probs looks like (2551,)\n", + "logs prob looks like torch.Size([2551])\n", + "torch.from_numpy(rewards) looks like torch.Size([2551])\n", + "rewards looks like (1330,)\n", + "log_probs looks like (1330,)\n", + "logs prob looks like torch.Size([1330])\n", + "torch.from_numpy(rewards) looks like torch.Size([1330])\n", + "rewards looks like (1510,)\n", + "log_probs looks like (1510,)\n", + "logs prob looks like torch.Size([1510])\n", + "torch.from_numpy(rewards) looks like torch.Size([1510])\n", + "rewards looks like (1330,)\n", + "log_probs looks like (1330,)\n", + "logs prob looks like torch.Size([1330])\n", + "torch.from_numpy(rewards) looks like torch.Size([1330])\n", + "rewards looks like (2157,)\n", + "log_probs looks like (2157,)\n", + "logs prob looks like torch.Size([2157])\n", + "torch.from_numpy(rewards) looks like torch.Size([2157])\n", + "rewards looks like (1276,)\n", + "log_probs looks like (1276,)\n", + "logs prob looks like torch.Size([1276])\n", + "torch.from_numpy(rewards) looks like torch.Size([1276])\n", + "rewards looks like (1188,)\n", + "log_probs looks like (1188,)\n", + "logs prob looks like torch.Size([1188])\n", + "torch.from_numpy(rewards) looks like torch.Size([1188])\n", + "rewards looks 
like (2381,)\n", + "log_probs looks like (2381,)\n", + "logs prob looks like torch.Size([2381])\n", + "torch.from_numpy(rewards) looks like torch.Size([2381])\n", + "rewards looks like (1450,)\n", + "log_probs looks like (1450,)\n", + "logs prob looks like torch.Size([1450])\n", + "torch.from_numpy(rewards) looks like torch.Size([1450])\n", + "rewards looks like (1612,)\n", + "log_probs looks like (1612,)\n", + "logs prob looks like torch.Size([1612])\n", + "torch.from_numpy(rewards) looks like torch.Size([1612])\n", + "rewards looks like (1780,)\n", + "log_probs looks like (1780,)\n", + "logs prob looks like torch.Size([1780])\n", + "torch.from_numpy(rewards) looks like torch.Size([1780])\n", + "rewards looks like (1350,)\n", + "log_probs looks like (1350,)\n", + "logs prob looks like torch.Size([1350])\n", + "torch.from_numpy(rewards) looks like torch.Size([1350])\n", + "rewards looks like (1459,)\n", + "log_probs looks like (1459,)\n", + "logs prob looks like torch.Size([1459])\n", + "torch.from_numpy(rewards) looks like torch.Size([1459])\n", + "rewards looks like (1958,)\n", + "log_probs looks like (1958,)\n", + "logs prob looks like torch.Size([1958])\n", + "torch.from_numpy(rewards) looks like torch.Size([1958])\n", + "rewards looks like (1325,)\n", + "log_probs looks like (1325,)\n", + "logs prob looks like torch.Size([1325])\n", + "torch.from_numpy(rewards) looks like torch.Size([1325])\n", + "rewards looks like (2168,)\n", + "log_probs looks like (2168,)\n", + "logs prob looks like torch.Size([2168])\n", + "torch.from_numpy(rewards) looks like torch.Size([2168])\n", + "rewards looks like (1682,)\n", + "log_probs looks like (1682,)\n", + "logs prob looks like torch.Size([1682])\n", + "torch.from_numpy(rewards) looks like torch.Size([1682])\n", + "rewards looks like (852,)\n", + "log_probs looks like (852,)\n", + "logs prob looks like torch.Size([852])\n", + "torch.from_numpy(rewards) looks like torch.Size([852])\n", + "rewards looks like (1757,)\n", + 
"log_probs looks like (1757,)\n", + "logs prob looks like torch.Size([1757])\n", + "torch.from_numpy(rewards) looks like torch.Size([1757])\n", + "rewards looks like (2313,)\n", + "log_probs looks like (2313,)\n", + "logs prob looks like torch.Size([2313])\n", + "torch.from_numpy(rewards) looks like torch.Size([2313])\n", + "rewards looks like (1662,)\n", + "log_probs looks like (1662,)\n", + "logs prob looks like torch.Size([1662])\n", + "torch.from_numpy(rewards) looks like torch.Size([1662])\n", + "rewards looks like (1559,)\n", + "log_probs looks like (1559,)\n", + "logs prob looks like torch.Size([1559])\n", + "torch.from_numpy(rewards) looks like torch.Size([1559])\n", + "rewards looks like (2077,)\n", + "log_probs looks like (2077,)\n", + "logs prob looks like torch.Size([2077])\n", + "torch.from_numpy(rewards) looks like torch.Size([2077])\n", + "rewards looks like (2119,)\n", + "log_probs looks like (2119,)\n", + "logs prob looks like torch.Size([2119])\n", + "torch.from_numpy(rewards) looks like torch.Size([2119])\n", + "rewards looks like (954,)\n", + "log_probs looks like (954,)\n", + "logs prob looks like torch.Size([954])\n", + "torch.from_numpy(rewards) looks like torch.Size([954])\n", + "rewards looks like (1797,)\n", + "log_probs looks like (1797,)\n", + "logs prob looks like torch.Size([1797])\n", + "torch.from_numpy(rewards) looks like torch.Size([1797])\n", + "rewards looks like (1579,)\n", + "log_probs looks like (1579,)\n", + "logs prob looks like torch.Size([1579])\n", + "torch.from_numpy(rewards) looks like torch.Size([1579])\n", + "rewards looks like (1277,)\n", + "log_probs looks like (1277,)\n", + "logs prob looks like torch.Size([1277])\n", + "torch.from_numpy(rewards) looks like torch.Size([1277])\n", + "rewards looks like (1196,)\n", + "log_probs looks like (1196,)\n", + "logs prob looks like torch.Size([1196])\n", + "torch.from_numpy(rewards) looks like torch.Size([1196])\n", + "rewards looks like (1294,)\n", + "log_probs looks like 
(1294,)\n", + "logs prob looks like torch.Size([1294])\n", + "torch.from_numpy(rewards) looks like torch.Size([1294])\n", + "rewards looks like (1318,)\n", + "log_probs looks like (1318,)\n", + "logs prob looks like torch.Size([1318])\n", + "torch.from_numpy(rewards) looks like torch.Size([1318])\n", + "rewards looks like (2605,)\n", + "log_probs looks like (2605,)\n", + "logs prob looks like torch.Size([2605])\n", + "torch.from_numpy(rewards) looks like torch.Size([2605])\n", + "rewards looks like (2002,)\n", + "log_probs looks like (2002,)\n", + "logs prob looks like torch.Size([2002])\n", + "torch.from_numpy(rewards) looks like torch.Size([2002])\n", + "rewards looks like (1354,)\n", + "log_probs looks like (1354,)\n", + "logs prob looks like torch.Size([1354])\n", + "torch.from_numpy(rewards) looks like torch.Size([1354])\n", + "rewards looks like (1785,)\n", + "log_probs looks like (1785,)\n", + "logs prob looks like torch.Size([1785])\n", + "torch.from_numpy(rewards) looks like torch.Size([1785])\n", + "rewards looks like (781,)\n", + "log_probs looks like (781,)\n", + "logs prob looks like torch.Size([781])\n", + "torch.from_numpy(rewards) looks like torch.Size([781])\n", + "rewards looks like (1965,)\n", + "log_probs looks like (1965,)\n", + "logs prob looks like torch.Size([1965])\n", + "torch.from_numpy(rewards) looks like torch.Size([1965])\n", + "rewards looks like (1135,)\n", + "log_probs looks like (1135,)\n", + "logs prob looks like torch.Size([1135])\n", + "torch.from_numpy(rewards) looks like torch.Size([1135])\n", + "rewards looks like (1672,)\n", + "log_probs looks like (1672,)\n", + "logs prob looks like torch.Size([1672])\n", + "torch.from_numpy(rewards) looks like torch.Size([1672])\n", + "rewards looks like (1278,)\n", + "log_probs looks like (1278,)\n", + "logs prob looks like torch.Size([1278])\n", + "torch.from_numpy(rewards) looks like torch.Size([1278])\n", + "rewards looks like (2499,)\n", + "log_probs looks like (2499,)\n", + "logs 
prob looks like torch.Size([2499])\n", + "torch.from_numpy(rewards) looks like torch.Size([2499])\n", + "rewards looks like (1275,)\n", + "log_probs looks like (1275,)\n", + "logs prob looks like torch.Size([1275])\n", + "torch.from_numpy(rewards) looks like torch.Size([1275])\n", + "rewards looks like (1144,)\n", + "log_probs looks like (1144,)\n", + "logs prob looks like torch.Size([1144])\n", + "torch.from_numpy(rewards) looks like torch.Size([1144])\n", + "rewards looks like (1605,)\n", + "log_probs looks like (1605,)\n", + "logs prob looks like torch.Size([1605])\n", + "torch.from_numpy(rewards) looks like torch.Size([1605])\n", + "rewards looks like (1178,)\n", + "log_probs looks like (1178,)\n", + "logs prob looks like torch.Size([1178])\n", + "torch.from_numpy(rewards) looks like torch.Size([1178])\n", + "rewards looks like (3269,)\n", + "log_probs looks like (3269,)\n", + "logs prob looks like torch.Size([3269])\n", + "torch.from_numpy(rewards) looks like torch.Size([3269])\n", + "rewards looks like (1492,)\n", + "log_probs looks like (1492,)\n", + "logs prob looks like torch.Size([1492])\n", + "torch.from_numpy(rewards) looks like torch.Size([1492])\n", + "rewards looks like (1285,)\n", + "log_probs looks like (1285,)\n", + "logs prob looks like torch.Size([1285])\n", + "torch.from_numpy(rewards) looks like torch.Size([1285])\n", + "rewards looks like (1687,)\n", + "log_probs looks like (1687,)\n", + "logs prob looks like torch.Size([1687])\n", + "torch.from_numpy(rewards) looks like torch.Size([1687])\n", + "rewards looks like (1124,)\n", + "log_probs looks like (1124,)\n", + "logs prob looks like torch.Size([1124])\n", + "torch.from_numpy(rewards) looks like torch.Size([1124])\n", + "rewards looks like (2043,)\n", + "log_probs looks like (2043,)\n", + "logs prob looks like torch.Size([2043])\n", + "torch.from_numpy(rewards) looks like torch.Size([2043])\n", + "rewards looks like (1280,)\n", + "log_probs looks like (1280,)\n", + "logs prob looks like 
torch.Size([1280])\n", + "torch.from_numpy(rewards) looks like torch.Size([1280])\n", + "rewards looks like (1418,)\n", + "log_probs looks like (1418,)\n", + "logs prob looks like torch.Size([1418])\n", + "torch.from_numpy(rewards) looks like torch.Size([1418])\n", + "rewards looks like (1365,)\n", + "log_probs looks like (1365,)\n", + "logs prob looks like torch.Size([1365])\n", + "torch.from_numpy(rewards) looks like torch.Size([1365])\n", + "rewards looks like (1091,)\n", + "log_probs looks like (1091,)\n", + "logs prob looks like torch.Size([1091])\n", + "torch.from_numpy(rewards) looks like torch.Size([1091])\n", + "rewards looks like (1279,)\n", + "log_probs looks like (1279,)\n", + "logs prob looks like torch.Size([1279])\n", + "torch.from_numpy(rewards) looks like torch.Size([1279])\n", + "rewards looks like (1109,)\n", + "log_probs looks like (1109,)\n", + "logs prob looks like torch.Size([1109])\n", + "torch.from_numpy(rewards) looks like torch.Size([1109])\n", + "rewards looks like (1285,)\n", + "log_probs looks like (1285,)\n", + "logs prob looks like torch.Size([1285])\n", + "torch.from_numpy(rewards) looks like torch.Size([1285])\n", + "rewards looks like (1222,)\n", + "log_probs looks like (1222,)\n", + "logs prob looks like torch.Size([1222])\n", + "torch.from_numpy(rewards) looks like torch.Size([1222])\n", + "rewards looks like (1538,)\n", + "log_probs looks like (1538,)\n", + "logs prob looks like torch.Size([1538])\n", + "torch.from_numpy(rewards) looks like torch.Size([1538])\n", + "rewards looks like (1139,)\n", + "log_probs looks like (1139,)\n", + "logs prob looks like torch.Size([1139])\n", + "torch.from_numpy(rewards) looks like torch.Size([1139])\n", + "rewards looks like (1354,)\n", + "log_probs looks like (1354,)\n", + "logs prob looks like torch.Size([1354])\n", + "torch.from_numpy(rewards) looks like torch.Size([1354])\n", + "rewards looks like (1166,)\n", + "log_probs looks like (1166,)\n", + "logs prob looks like 
torch.Size([1166])\n", + "torch.from_numpy(rewards) looks like torch.Size([1166])\n", + "rewards looks like (1348,)\n", + "log_probs looks like (1348,)\n", + "logs prob looks like torch.Size([1348])\n", + "torch.from_numpy(rewards) looks like torch.Size([1348])\n", + "rewards looks like (1347,)\n", + "log_probs looks like (1347,)\n", + "logs prob looks like torch.Size([1347])\n", + "torch.from_numpy(rewards) looks like torch.Size([1347])\n", + "rewards looks like (2059,)\n", + "log_probs looks like (2059,)\n", + "logs prob looks like torch.Size([2059])\n", + "torch.from_numpy(rewards) looks like torch.Size([2059])\n", + "rewards looks like (2021,)\n", + "log_probs looks like (2021,)\n", + "logs prob looks like torch.Size([2021])\n", + "torch.from_numpy(rewards) looks like torch.Size([2021])\n", + "rewards looks like (2232,)\n", + "log_probs looks like (2232,)\n", + "logs prob looks like torch.Size([2232])\n", + "torch.from_numpy(rewards) looks like torch.Size([2232])\n", + "rewards looks like (1102,)\n", + "log_probs looks like (1102,)\n", + "logs prob looks like torch.Size([1102])\n", + "torch.from_numpy(rewards) looks like torch.Size([1102])\n", + "rewards looks like (1165,)\n", + "log_probs looks like (1165,)\n", + "logs prob looks like torch.Size([1165])\n", + "torch.from_numpy(rewards) looks like torch.Size([1165])\n", + "rewards looks like (1264,)\n", + "log_probs looks like (1264,)\n", + "logs prob looks like torch.Size([1264])\n", + "torch.from_numpy(rewards) looks like torch.Size([1264])\n", + "rewards looks like (1346,)\n", + "log_probs looks like (1346,)\n", + "logs prob looks like torch.Size([1346])\n", + "torch.from_numpy(rewards) looks like torch.Size([1346])\n", + "rewards looks like (2848,)\n", + "log_probs looks like (2848,)\n", + "logs prob looks like torch.Size([2848])\n", + "torch.from_numpy(rewards) looks like torch.Size([2848])\n", + "rewards looks like (938,)\n", + "log_probs looks like (938,)\n", + "logs prob looks like torch.Size([938])\n", 
+ "torch.from_numpy(rewards) looks like torch.Size([938])\n", + "rewards looks like (1069,)\n", + "log_probs looks like (1069,)\n", + "logs prob looks like torch.Size([1069])\n", + "torch.from_numpy(rewards) looks like torch.Size([1069])\n", + "rewards looks like (2588,)\n", + "log_probs looks like (2588,)\n", + "logs prob looks like torch.Size([2588])\n", + "torch.from_numpy(rewards) looks like torch.Size([2588])\n", + "rewards looks like (1461,)\n", + "log_probs looks like (1461,)\n", + "logs prob looks like torch.Size([1461])\n", + "torch.from_numpy(rewards) looks like torch.Size([1461])\n", + "rewards looks like (2153,)\n", + "log_probs looks like (2153,)\n", + "logs prob looks like torch.Size([2153])\n", + "torch.from_numpy(rewards) looks like torch.Size([2153])\n", + "rewards looks like (2312,)\n", + "log_probs looks like (2312,)\n", + "logs prob looks like torch.Size([2312])\n", + "torch.from_numpy(rewards) looks like torch.Size([2312])\n", + "rewards looks like (1636,)\n", + "log_probs looks like (1636,)\n", + "logs prob looks like torch.Size([1636])\n", + "torch.from_numpy(rewards) looks like torch.Size([1636])\n", + "rewards looks like (2019,)\n", + "log_probs looks like (2019,)\n", + "logs prob looks like torch.Size([2019])\n", + "torch.from_numpy(rewards) looks like torch.Size([2019])\n", + "rewards looks like (1450,)\n", + "log_probs looks like (1450,)\n", + "logs prob looks like torch.Size([1450])\n", + "torch.from_numpy(rewards) looks like torch.Size([1450])\n", + "rewards looks like (2105,)\n", + "log_probs looks like (2105,)\n", + "logs prob looks like torch.Size([2105])\n", + "torch.from_numpy(rewards) looks like torch.Size([2105])\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vNb_tuFYhKVK" + }, + "source": [ + "### Training Result\n", + "During the training process, we recorded `avg_total_reward`, which represents the average total reward of episodes before updating the policy network.\n", 
+ "\n", + "Theoretically, if the agent becomes better, the `avg_total_reward` will increase.\n", + "The visualization of the training process is shown below: \n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "wZYOI8H10SHN", + "outputId": "54840043-6fe4-4771-e8c9-78785c55aa79" + }, + "source": [ + "plt.plot(avg_total_rewards)\n", + "plt.title(\"Total Rewards\")\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mV5jj4dThz0Y" + }, + "source": [ + "In addition, `avg_final_reward` represents the average final reward of the episodes. To be specific, the final reward is the last reward received in one episode, indicating whether the craft lands successfully or not.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "txDZ5vlGWz5w", + "outputId": "c7c5e9ca-6329-4ee2-f3d6-a1ba46b5aea2" + }, + "source": [ + "plt.plot(avg_final_rewards)\n", + "plt.title(\"Final Rewards\")\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u2HaGRVEYGQS" + }, + "source": [ + "## Testing\n", + "The testing result will be the average reward of 5 test episodes" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 286 + }, + "id": "5yFuUKKRYH73", + "outputId": "3ad20d21-4c56-47b7-e617-d71e4211aed7" + }, + "source": [ + "fix(env, seed)\n", + "agent.network.eval() # set the network into evaluation mode\n", + "NUM_OF_TEST = 5 # Do not revise this !!!\n", + "test_total_reward = []\n", + "action_list = []\n", + "for i in range(NUM_OF_TEST):\n", + " actions = []\n", + " state = env.reset()\n", + "\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + " while not done:\n", + " action, _ = agent.sample(state)\n", + " actions.append(action)\n", + " state, reward, done, _ = env.step(action)\n", + "\n", + " total_reward += reward\n", + "\n", + " img.set_data(env.render(mode='rgb_array'))\n", + " display.display(plt.gcf())\n", + " display.clear_output(wait=True)\n", + " \n", + " print(total_reward)\n", + " test_total_reward.append(total_reward)\n", + "\n", + " action_list.append(actions) # save the result of testing \n" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-207.9114975585693\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Aex7mcKr0J01", + "outputId": "a36aaa35-ec20-4089-ddde-ab4742d3e90e" + }, + "source": [ + "print(np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-147.2620449863271\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "leyebGYRpqsF" + }, + "source": [ + "Action list" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hGAH4YWDpp4u", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "24f547ee-6648-4d2a-f9c7-718b23e93251" + }, + "source": [ + "print(\"Action list looks like \", action_list)\n", + "print(\"Action list's shape looks like \", np.shape(action_list))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Action list looks like [[2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 0, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3], [2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 0, 2, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 2, 2, 2, 3, 3, 2, 3, 0, 2, 3, 2, 0, 2, 3, 3, 2, 3, 2, 3, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 3, 2, 2, 0, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 0, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 2, 3, 2, 3, 3, 3, 2, 3, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3, 3, 2, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 1, 2, 2, 0, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 0, 2, 0, 3, 2, 3, 2, 0, 2, 0, 2, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 2, 2, 3, 2, 0, 2, 3, 2, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 
1, 2, 2, 0, 2, 1, 2, 1, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 3, 3, 2, 3, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 3, 2, 3, 3, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 2, 3, 2, 3, 2, 2, 2, 3, 2, 0, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2, 3, 0, 3, 2, 3, 3, 2, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 2, 3, 3, 3, 2, 2, 3, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 2, 1, 2, 1], [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 1, 2, 1, 2, 2, 1, 1, 2, 0, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 1, 0, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 3, 2, 3, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 3, 
2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 0, 0, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 0, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 3, 3, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 3, 2, 2, 0, 2, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]\n", + "Action list's shape looks like (5,)\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " return array(a, dtype, copy=False, order=order)\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fNkmwucrHMen" + }, + "source": [ + "Analysis of actions taken by agent" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WHdAItjj1nxw", + "outputId": "20aea8b1-e011-4397-b775-7b4c4b593871" + }, + "source": [ + "distribution = {}\n", + "for actions in action_list:\n", + " for action in actions:\n", + " if action not in distribution.keys():\n", + " distribution[action] = 1\n", + " else:\n", + " distribution[action] += 1\n", + "print(distribution)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "{2: 1144, 1: 516, 0: 30, 3: 501}\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ricE0schY75M" + }, + "source": [ + "Saving the result of Model Testing\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GZsMkGmIY42b", + "outputId": "c47a3123-eb1b-4dc1-f1b2-7a09a82cd8a6" + }, + "source": [ + "PATH = \"Action_List.npy\" # Can be modified into the name or path you want\n", + "np.save(PATH ,np.array(action_list)) " + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " \n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "asK7WfbkaLjt" + }, + "source": [ + "### This is the file you need to submit !!!\n", + "Download the testing result to your device\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "id": "c-CqyhHzaWAL", + "outputId": "38653c82-673e-4f90-8746-3a0424fe3aca" + }, + "source": [ + "from google.colab import files\n", + "files.download(PATH)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = 
window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "download(\"download_899cef0d-01bc-40fd-a501-1573f2382641\", \"Action_List.npy\", 4689)" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "seT4NUmWmAZ1" + }, + "source": [ + "# Server \n", + "The code below simulates the environment on the judge server. It can be used for testing." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U69c-YTxaw6b", + "outputId": "65dd4cf4-e667-469d-a7dc-34f6a053e1d0" + }, + "source": [ + "action_list = np.load(PATH,allow_pickle=True) # The action list you upload\n", + "seed = 543 # Do not revise this\n", + "fix(env, seed)\n", + "\n", + "agent.network.eval() # set network to evaluation mode\n", + "\n", + "test_total_reward = []\n", + "if len(action_list) != 5:\n", + " print(\"Wrong format of file !!!\")\n", + " exit(0)\n", + "for actions in action_list:\n", + " state = env.reset()\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + "\n", + " for action in actions:\n", + " \n", + " state, reward, done, _ = env.step(action)\n", + " total_reward += reward\n", + " if done:\n", + " break\n", + "\n", + " print(f\"Your reward is : %.2f\"%total_reward)\n", + " test_total_reward.append(total_reward)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/__init__.py:422: UserWarning: torch.set_deterministic is deprecated and will be removed in a future release. 
Please use torch.use_deterministic_algorithms instead\n", + " \"torch.set_deterministic is deprecated and will be removed in a future \"\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "Your reward is : -29.53\n", + "Your reward is : -36.44\n", + "Your reward is : -194.16\n", + "Your reward is : -268.27\n", + "Your reward is : -207.91\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAc+klEQVR4nO3de3RU9b338feXJCRcwv1iCAmgDQcRFSFC9NAWWbVaq8JxsRR7qBbx0Cqtlx4ROc9qS/sstVqrz3HJ0getFts+qHgry9qDFpDWWpWLEQGxIuUWgSg3wZDLZL7PH7NDRyD3mUx25vNaa6/Z+7f3zP7+MptPNr/Zk23ujoiIhEenVBcgIiLNo+AWEQkZBbeISMgouEVEQkbBLSISMgpuEZGQSVpwm9nFZvaBmW0xszuStR8RkXRjybiO28wygL8DFwK7gNXA1e6+KeE7ExFJM8k64x4HbHH3re5eDTwFTE7SvkRE0kpmkl43H9gZt7wLGF/fxmamr29KQnXt2pvMzByqKj+jqrqCnJzudO8ygKyMbq163dpoFZ9Xfcrnn+8jKyuHrl17E4nUUFGxH/dogqoXiXF3O1l7soK7UWY2C5iVqv1Lx9Wjxylc8JVbGN7vYjaVv8DyFfczfvy3KR46k9zsvFa99oGj/2D1R4+xZs3TnF8yk7OHTONo5ABvblzI+vVLE9QDkYYla6ikDCiIWx4ctB3j7gvdvdjdi5NUg6Sp006bQH7PsThRKioOEI3WJnwf1dUV7Cx7h4OV2+nXdThDBo8jN3dAwvcjcjLJCu7VQJGZDTOzzsA0QKcjknS9euVTOKiYXjlD2H34Xbb+429UVx9Nyr527lxH2f61RKKVDOo5hmHDSpKyH5HjJSW43T0CfB9YBrwPPOPuG5OxL5E6mZmd+dJpX2Fwr3Oprv2cj/etY9eu0qTtr7q6gq3/+Bt7Dr9Lny6nMiT/XHr0OCVp+xOpk7Qxbnd/GXg5Wa8vcrz8/LMYMexiemTns2H3s6x751lqa2uA2Oc7tV5DJFrZqn1EveYLyzt2rOXjYaWckjuaQb3GMGRIMe+991Kr9iHSmJR9OCmSSBkZWZx22gTyuo9m/9GtbC97i0OHdgdrnX37trGl+zKMk35IfxwDTn6hkxPl00+3HVuORKr4aOvr5PU+m/we4xiSP45t297m8OHy1nZJpF4KbukQCgrOYVDv0WR2yubjg++wffvaL6zftOkVPvrorwnZV1XVkS8s79xZyu5T3+WU7mfTs9tgCgvHsHHjMuoLf5HWUnBLh1BdXcHH+96lIrKPHR+v5tChj7+wPhqNcPTooaTsOxKpYuOmP9I1py+fHf44ONNXaEvyJOUr780uQl/AkQTIzu5Obm5/Dh7cTSTSurHslujZcxAVFQepqalo831Lx1TfF3AU3CIi7VR9wa0/6yoiEjIKbhGRkFFwi4iEjIJbRCRkFNwiIiGj4BYRCRkFt4hIyCi4RURCRsEtIhIyCm4RkZBRcIuIhIyCW0QkZBTcI
iIh06q/x21m24DDQC0QcfdiM+sDPA0MBbYBV7r7gdaVKSIidRJxxn2Bu4929+Jg+Q5gubsXAcuDZRERSZBkDJVMBhYF84uAKUnYh4hI2mptcDvwipmtNbNZQdtAd6+7S+seYGAr9yEiInFae8/JCe5eZmYDgFfNbHP8Snf3+u5uEwT9rJOtExGR+iXs1mVmNh84AvwHMNHdd5tZHvCau/9LI8/VrctERI6T8FuXmVk3M8utmwe+DmwAlgLXBptdC/y+pfsQEZETtfiM28xOBV4IFjOB/+fud5pZX+AZoBDYTuxywP2NvJbOuEVEjqO7vIuIhIzu8i4i0kEouEVEQkbBLSISMgpuEZGQUXCLiISMgltEJGQU3CIiIaPgFhEJGQW3iEjIKLhFREJGwS0iEjIKbhGRkFFwi4iEjIJbRCRkFNwiIiGj4BYRCRkFt4hIyCi4RURCRsEtIhIyjQa3mT1uZuVmtiGurY+ZvWpmHwaPvYN2M7MHzWyLma03szHJLF5EJB015Yz718DFx7XdASx39yJgebAM8A2gKJhmAQ8npkwREanTaHC7+5+B/cc1TwYWBfOLgClx7U96zJtALzPLS1SxIiLS8jHuge6+O5jfAwwM5vOBnXHb7QraTmBms8xsjZmtaWENIiJpKbO1L+DubmbeguctBBYCtOT5IiLpqqVn3HvrhkCCx/KgvQwoiNtucNAmIiIJ0tLgXgpcG8xfC/w+rv2a4OqSEuBQ3JCKiIgkgLk3PEphZouBiUA/YC/wE+BF4BmgENgOXOnu+83MgIeIXYVSAcxw90bHsDVUIiJyIne3k7U3GtxtQcEtInKi+oJb35wUEQkZBbeISMgouEVEQkbBLSISMgpuEZGQUXCLiISMgltEJGQU3CIiIaPgFhEJGQW3iEjIKLhFREJGwS0iEjIKbhGRkFFwi4iEjIJbRCRkFNwiIiGj4BYRCRkFt4hIyDQa3Gb2uJmVm9mGuLb5ZlZmZqXBdEncunlmtsXMPjCzi5JVuIhIumrKzYK/AhwBnnT3UUHbfOCIu9933LYjgcXAOGAQ8CdguLvXNrIP3XNSROQ4Lb7npLv/GdjfxP1MBp5y9yp3/wewhViIi4hIgrRmjPv7ZrY+GErpHbTlAzvjttkVtJ3AzGaZ2RozW9OKGkRE0k5Lg/th4DRgNLAb+GVzX8DdF7p7sbsXt7AGEZG01KLgdve97l7r7lHgUf45HFIGFMRtOjhoExGRBGlRcJtZXtzivwF1V5wsBaaZWbaZDQOKgLdbV6KIiMTLbGwDM1sMTAT6mdku4CfARDMbDTiwDfgugLtvNLNngE1ABJjd2BUlIiLSPI1eDtgmRehyQBGRE7T4ckAREWlfFNwiIiGj4BYRCRkFt4hIyCi4RURCRsEtIhIyCm4RkZBRcIuIhIyCW0QkZBTcIiIho+AWEQkZBbeISMgouEVEQkbBLSISMgpuEZGQUXCLiISMgltEJGQU3CIiIdNocJtZgZmtNLNNZrbRzG4O2vuY2atm9mHw2DtoNzN70My2mNl6MxuT7E6IiKSTppxxR4D/dPeRQAkw28xGAncAy929CFgeLAN8g9jd3YuAWcDDCa9aRCSNNRrc7r7b3dcF84eB94F8YDKwKNhsETAlmJ8MPOkxbwK9zCwv4ZWLiKSpZo1xm9lQ4BzgLWCgu+8OVu0BBgbz+cDOuKftCtqOf61ZZrbGzNY0s2YRkbTW5OA2s+7Ac8At7v5Z/Dp3d8Cbs2N3X+juxe5e3JzniYikuyYFt5llEQvt37n780Hz3rohkOCxPGgvAwrinj44aBMRkQRoylUlBvwKeN/d749btRS4Npi/Fvh9XPs1wdUlJcChuCEVERFpJYuNcjSwgdkE4C/Ae0A0aP4vYuPczwCFwHbgSnffHwT9Q8DFQAUww90bHMc2s2YNs4iIpAN3t5O1NxrcbUHBLSJyovqCW9+cFBEJGQW3iEjIKLhFREJGwS0iEjIKbhGRkFFwi4iEj
IJbRCRkFNwiIiGj4BYRCRkFt4hIyCi4RURCRsEtIhIyCm4RkZBRcIuIhIyCW0QkZBTcIiIho+AWEQkZBbeISMg05WbBBWa20sw2mdlGM7s5aJ9vZmVmVhpMl8Q9Z56ZbTGzD8zsomR2QEQk3TTlZsF5QJ67rzOzXGAtMAW4Ejji7vcdt/1IYDEwDhgE/AkY7u61DexD95wUETlOi+856e673X1dMH8YeB/Ib+Apk4Gn3L3K3f8BbCEW4iIikgDNGuM2s6HAOcBbQdP3zWy9mT1uZr2DtnxgZ9zTdtFw0IsAcNdd3+Wee2DUKBg5EgYNSnVFbW/ixIn8+tf/wiWXwBlnwIgRkJGR6qqkvcls6oZm1h14DrjF3T8zs4eB/w148PhL4LpmvN4sYFbzypWO7MwzTyUvDyZNii3v3g2bNsXm/+d/YMsWcIc9e6C23oG3cOvfvz/jxh3hjDNiy5EIvPEG1NTArl3w4oux9kOH4PDh1NUpqdWk4DazLGKh/Tt3fx7A3ffGrX8UeClYLAMK4p4+OGj7AndfCCwMnq8xbjnGglG9QYP+edZ9wQWx0K6thWXL4OjRWLD/9repqzOZ6n4GWVnw1a/G5t1h+vTY/IYN8MEHsfknn4S9e098Dem4mnJViQG/At539/vj2vPiNvs3YEMwvxSYZmbZZjYMKALeTlzJko6i0VhoRyJQUQGffx4L73RS94urthYqK2M/g88/j/1sJL005Yz7X4FvA++ZWWnQ9l/A1WY2mthQyTbguwDuvtHMngE2ARFgdkNXlIjEc49NEBsaKA2OuGXLYOvW2Lr9+zt+WNX9HCIRWLECqquhrAyWLo2tP3Ik/X5xyT81Gtzu/jpwsktSXm7gOXcCd7aiLklDR47AH/4QG/6IRmNjuJ98kuqq2l5pKTz6KGzfHvs57NjR8X9RSfM0+cNJkWTbsQPmz091Fal3//2wZk2qq5D2TF95FxEJGQW3iEjIKLhFREJGwS0iEjIKbhGRkFFwi4iEjIJbRCRkdB23iEg706lTJzp1qv+8WmfcIiLtRKdOnRgyZAg///nPGTVqVP3btWFNIiJSj06dOnHLLbewevVqbrvtNjIa+EPsGioREUmhzMxMvv71rzNv3jzGjh1Lly5dGn9OG9QlIiInUVxczP33309JSQlZWVlNfp6CW0SkjQ0ZMoQbb7yRq6++moKCgsafcBwFt4hIGzAzBg4cyM0338yMGTMYMGAAZie9iXujFNwiIknWuXNnLrnkEh599FH69OnT4KV+TaHgFhFJEjPjoosu4qabbuLLX/4y3bt3T8jrKrhFRBLMzJg0aRJz587lvPPOS1hg12k0uM0sB/gzkB1s/6y7/yS4EfBTQF9gLfBtd682s2zgSWAssA+4yt23JbRqEZF2qrCwkGnTpvHTn/6UnJycpOyjKWfcVcAkdz9iZlnA62b2R+CHwAPu/pSZPQLMBB4OHg+4+5fMbBpwD3BVUqoXEWknBg8ezIwZM5g+fTpFRUUt/uCxKZpys2AHjgSLWcHkwCTgW0H7ImA+seCeHMwDPAs8ZGYWvI6ISIeSk5PDnDlzuO666ygsLGz1B49N0aQxbjPLIDYc8iVgAfARcNDdI8Emu4D8YD4f2Ang7hEzO0RsOOXT+l4/Ly+PSCTCJ+l4S2855i9/+QvdunUjnX/Hb9u2jUOHDpGRkUFtbW2qy5EGZGVlUVJSwg033MCVV17Z4FfUE61Jwe3utcBoM+sFvACMaO2OzWwWMAtiY0JPPvkk3/rWtxTeaahTp05cccUV/PCHP+TRRx9NdTntwm9+8xtKS0tZsmQJkUgkrX+ZtTdmxsSJE7nhhhu46KKL6NGjR9sX4e7NmoAfA3OInUFnBm3nAcuC+WXAecF8ZrCdNfSaY8eO9Wg06lu3bvXrr7/eiQ3FaOrgU+fOnf3yyy/3N998048ePeryRdXV1b5582Z/7LHHfPLkyd6rV6+Uv2fpPvXt29dffPFFP3jwYNLf/
7Fjx7rXl8P1rfB/BnV/oFcw3wX4C3ApsASYFrQ/AtwYzM8GHgnmpwHPNLaPoEB3dz98+LDPnj3bBw0alPI3SVNypu7du/vll1/ub7zxhtfU1CT7+O8QotGol5aW+rPPPusTJkzw8ePHe2ZmZsrfy3SZ8vLy/Hvf+56/9957bfaetza4zwLeAdYDG4AfB+2nAm8DW4iFeHbQnhMsbwnWn9rYPuKDu87bb7/thYWFKX/DNCVu6tmzp0+ePNlXrVrl0Wg0eUd8GqiurvaXX37Zf/azn/npp5/u3bp1S/n72xGnvLw8v+WWW3z9+vVt/h63KrjbYjpZcNedYcyePVtnFiGfevTo4VOnTvUVK1Z4bW1tAg9tiUajHolE/A9/+IP/4he/8MGDB3tubm7K3/OwT3379vWzzz7bN23alLJjNpTBXae6utrnzJnjXbp0Sfmbqal5U05Ojk+bNs1XrVrlVVVVCTiUpSG1tbV+5MgRX7Vqlc+fP99nzJjh3bp104lPE6esrCzPzc31m2++2T/88EOvqKhI6f8MQx3c7u6VlZW+bt06LyoqcjNL+RusqeEpIyPDp06d6m+88YY+dEyhyspK3717tz/00EP+ne98x7t06aJ/P/UcrxMnTvSnn37a9+zZ024+dwl9cNcpKyvzu+66S2cQ7XQaNWqUT506VVeJtEORSMT//ve/++LFi33KlCnev3//lB8vqZ7qxq83b97cJleJNFeHCW732NDJvffe68OHD0/5G68pNo0aNcoXLFjgZWVlLTg8JRU2btzof/3rX33evHl+/vnne3Z2dsqPo7aYOnfu7BMnTvRXXnnFN2/enOq3oUEdKrjr7Nixw8ePH5/yAyFdJzPzoqIif+SRR7y8vLzZ75+0H5FIxF999VW/7777/PTTT/cePXqk/PhK9DR06FC/5557fNmyZR6JRFL9I2+ShoLb3J1UKy4u9jVr1jTrOe7Orl27eOKJJ7j77ruprKxMUnUSLyMjg8LCQm677TamT59Obm5uUv+YjrQddycajbJq1Sq2bt3K66+/ziuvvALAvn37qK6uTnGFzZObm8sFF1zAZZddxqWXXsrAgQNDdawWFxezZs2akxYc2uCuE41Geeyxx7jttts4fPhwgiuTOhkZGQwdOpQf/OAHzJw5k27duoXqH4E0X01NDTU1NQAsWbKEHTt2sHz5ctauXUtNTQ1VVVUprvBEnTp1oqCggOuuu44JEyYwYcIEOnfunOqyWqRDBzdAJBJh165dzJgxg9deey1xhQlmRlFREddffz3XXXcdffr0UWCnscOHD1NRUcG7777LCy+8wCeffMKLL74IxE6iUpUnmZmZTJ06lbPOOotZs2Z1iOO0wwd3nfLycubOncvzzz/PZ599loDK2kZJSQl5eXlfaHvttdc4cOBAiiqKGTVqFDNmzODqq68+oT4RgMrKSrZv3w7Ejtlly5axZ88e/va3v7XJ/s855xyGDx/OnDlzOPPMM0N7dn0yaRPcEBunW7JkCXfffTelpaUJec1EGzJkCIWFhZxxxhlcc801jBgxgt69ex9b7+5s3Lix3qGfJ554gs2bN590XVVVFatXr27Vmc+oUaO44YYbmDJlCoMGDWrx60h6OnDgAJs3b+bTTz/ll7/8JZFIhNWrVydsjLxPnz7H/u1cdtllDBw4MCGv296kVXDXOXjwIDNnzuT5559P6Ou21LBhwxgyZAi33norI0aMYPjw4UnZz9GjR1m5ciXRaPSEdZWVldxzzz1UVFTU+/ybbrqJK664gv79+yelPkk/tbW1rFy5ksrKSnbs2MGCBQuora3lww8/POlxejIZGRkUFRVRUlLCjTfeyLnnnpvkqlMvLYMb4NChQ/z2t79l/vz5fPppvfdxSIqePXvSq1cvbr31Vrp37843v/lNBgwYgJmlbOyt7lKihqSyPun46o7Bo0eP8txzz1FdXc2CBQsoLy8nEolQXl7+he379evHoEGDuP3225k6d
SpZWVltcoeZ9iBtgxtiB8qKFSu45ppr2LNnT5N/wzdXVlYW2dnZTJ8+nfz8fC688ELOOusscnJyFIQi9XB3qqqqiEajHDhwgEWLFlFdXc3ChQu56qqruPHGGykoKCA7Ozvt/h2ldXBD7ODYt28fc+fO5fHHH0/Y62ZkZNC1a1emTZtGSUkJl156Kb179yYrKyth+xBJN+7O/v376dmzJ5mZTbpJV4fUUHCnxU/FzOjXrx8PPvggvXv3ZvHixXz88cctfr2vfe1rxz7Ay8jIYNiwYWnz3zeRZDMz+vbtm+oy2rW0CO463bp147777uOKK67g9ttv580332zSDVlHjhzJ0KFDmTt3LtnZ2Zx55pl07dq1DSoWETlRWgV3nfPPP5/ly5fzox/9iAceeIBIJPKF9f369eOUU07h1ltvZcCAARQXF3PKKaekqFoRkS9Ky+AGyM7O5s4772TMmDHce++97N27l379+nHTTTdx9tlnM2bMGF1hISLtUqPBbWY5wJ+B7GD7Z939J2b2a+CrwKFg0++4e6nFku6/gUuAiqB9XTKKb62srCyuuuoqpkyZEvuLW2Zp+em1iIRLU864q4BJ7n7EzLKA183sj8G6Oe7+7HHbfwMoCqbxwMPBY7tkZuTk5KS6DBGRJmv0UojgT8MeCRazgqmhawgnA08Gz3sT6GVm+kMXIiIJ0qRr2Mwsw8xKgXLgVXd/K1h1p5mtN7MHzCw7aMsHdsY9fVfQJiIiCdCk4Hb3WncfDQwGxpnZKGAeMAI4F+gDzG3Ojs1slpmtMbM1n3zySTPLFhFJX8361oi7HwRWAhe7++5gOKQKeAIYF2xWBhTEPW1w0Hb8ay1092J3L9YfNBIRabpGg9vM+ptZr2C+C3AhsLlu3Dq4imQKsCF4ylLgGospAQ65++6kVC8ikoaaclVJHrDIzDKIBf0z7v6Sma0ws/6AAaXA94LtXyZ2KeAWYpcDzkh82SIi6avR4Hb39cA5J2mfVM/2DsxufWkiInIy+stIIiIho+AWEQkZBbeISMgouEVEQkbBLSISMgpuEZGQUXCLiISMgltEJGQU3CIiIaPgFhEJGQW3iEjIKLhFREJGwS0iEjIKbhGRkFFwi4iEjIJbRCRkFNwiIiGj4BYRCRkFt4hIyCi4RURCRsEtIhIyCm4RkZAxd091DZjZYeCDVNeRJP2AT1NdRBJ01H5Bx+2b+hUuQ9y9/8lWZLZ1JfX4wN2LU11EMpjZmo7Yt47aL+i4fVO/Og4NlYiIhIyCW0QkZNpLcC9MdQFJ1FH71lH7BR23b+pXB9EuPpwUEZGmay9n3CIi0kQpD24zu9jMPjCzLWZ2R6rraS4ze9zMys1sQ1xbHzN71cw+DB57B+1mZg8GfV1vZmNSV3nDzKzAzFaa2SYz22hmNwftoe6bmeWY2dtm9m7Qr58G7cPM7K2g/qfNrHPQnh0sbwnWD01l/Y0xswwze8fMXgqWO0q/tpnZe2ZWamZrgrZQH4utkdLgNrMMYAHwDWAkcLWZjUxlTS3wa+Di49ruAJa7exGwPFiGWD+LgmkW8HAb1dgSEeA/3X0kUALMDt6bsPetCpjk7mcDo4GLzawEuAd4wN2/BBwAZgbbzwQOBO0PBNu1ZzcD78ctd5R+AVzg7qPjLv0L+7HYcu6esgk4D1gWtzwPmJfKmlrYj6HAhrjlD4C8YD6P2HXqAP8XuPpk27X3Cfg9cGFH6hvQFVgHjCf2BY7MoP3YcQksA84L5jOD7SzVtdfTn8HEAmwS8BJgHaFfQY3bgH7HtXWYY7G5U6qHSvKBnXHLu4K2sBvo7ruD+T3AwGA+lP0N/ht9DvAWHaBvwXBCKVAOvAp8BBx090iwSXztx/oVrD8E9G3bipvs/wC3A9FguS8do18ADrxiZmvNbFbQFvpjsaXayzcnOyx3dzML7aU7ZtYdeA64xd0/M7Nj68LaN3evBUabWS/gBWBEiktqNTO7FCh397VmNjHV9STBBHcvM
7MBwKtmtjl+ZViPxZZK9Rl3GVAQtzw4aAu7vWaWBxA8lgftoeqvmWURC+3fufvzQXOH6BuAux8EVhIbQuhlZnUnMvG1H+tXsL4nsK+NS22KfwUuN7NtwFPEhkv+m/D3CwB3Lwsey4n9sh1HBzoWmyvVwb0aKAo++e4MTAOWprimRFgKXBvMX0tsfLiu/ZrgU+8S4FDcf/XaFYudWv8KeN/d749bFeq+mVn/4EwbM+tCbNz+fWIBPjXY7Ph+1fV3KrDCg4HT9sTd57n7YHcfSuzf0Qp3/3dC3i8AM+tmZrl188DXgQ2E/FhslVQPsgOXAH8nNs74v1JdTwvqXwzsBmqIjaXNJDZWuBz4EPgT0CfY1ohdRfMR8B5QnOr6G+jXBGLjiuuB0mC6JOx9A84C3gn6tQH4cdB+KvA2sAVYAmQH7TnB8pZg/amp7kMT+jgReKmj9Cvow7vBtLEuJ8J+LLZm0jcnRURCJtVDJSIi0kwKbhGRkFFwi4iEjIJbRCRkFNwiIiGj4BYRCRkFt4hIyCi4RURC5v8Dix/mJUV1WnwAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjFBWwQP1hVe" + }, + "source": [ + "# Your score" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GpJpZz3Wbm0X", + "outputId": "7d4677c7-b285-42d3-d8c0-5f0a0d8230c8" + }, + "source": [ + "print(f\"Your final reward is : %.2f\"%np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Your final reward is : -147.26\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wUBtYXG2eaqf" + }, + "source": [ + "## Reference\n", + "\n", + "Below are some useful tips for you to get high score.\n", + "\n", + "- [DRL Lecture 1: Policy Gradient (Review)](https://youtu.be/z95ZYgPgXOY)\n", + "- [ML Lecture 23-3: Reinforcement Learning (including Q-learning) start at 30:00](https://youtu.be/2-JNBzCq77c?t=1800)\n", + "- [Lecture 7: Policy Gradient, David Silver](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/pg.pdf)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eZ7VDw-C19Qe" + }, + "source": [ + "" + ] + } + ] +} \ No newline at end of file diff --git a/11 Quantum ML/GuestLecture_QML.pdf b/11 Quantum ML/课件/GuestLecture_QML.pdf similarity index 100% rename from 11 Quantum ML/GuestLecture_QML.pdf rename to 11 Quantum ML/课件/GuestLecture_QML.pdf diff --git a/12 Life-Long&Compression/课件/life_v2.pdf b/12 Life-Long&Compression/课件/life_v2.pdf new file mode 100644 index 0000000..8b27fae Binary files /dev/null and b/12 Life-Long&Compression/课件/life_v2.pdf differ diff --git a/12 Life-Long&Compression/课件/life_v2.pptx b/12 Life-Long&Compression/课件/life_v2.pptx new file mode 100644 index 0000000..c6eecbd Binary files /dev/null and b/12 Life-Long&Compression/课件/life_v2.pptx differ diff --git a/12 Life-Long&Compression/课件/tiny_v6.pdf b/12 
Life-Long&Compression/课件/tiny_v6.pdf new file mode 100644 index 0000000..5f9cf12 Binary files /dev/null and b/12 Life-Long&Compression/课件/tiny_v6.pdf differ diff --git a/12 Life-Long&Compression/课件/tiny_v6.pptx b/12 Life-Long&Compression/课件/tiny_v6.pptx new file mode 100644 index 0000000..b4bc9f4 Binary files /dev/null and b/12 Life-Long&Compression/课件/tiny_v6.pptx differ diff --git a/README.md b/README.md index 9956fae..49b0717 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,10 @@ 2021/05/21/ 更新RL 及 HW11 +2021/05/28/ 更新Quantum ML + +2021/06/04/ 更新Life-Long&Compression 及 HW12 + #------------------------------------------------------------------# B站视频地址:https://www.bilibili.com/video/BV1Wv411h7kN#reply4197445138 diff --git a/范例/HW12/HW12_EN.ipynb b/范例/HW12/HW12_EN.ipynb new file mode 100644 index 0000000..5694e58 --- /dev/null +++ b/范例/HW12/HW12_EN.ipynb @@ -0,0 +1,3092 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hw12_reinforcement_learning_english_version.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "de3a153737af485ea436d7e8393d8248": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_6345f8926212465291c04587353161f1", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_3aa84c8c097d4858a0c38f18dec8b060", + "IPY_MODEL_6647e68cf064416ca593b990bed81edf" + ] + } + }, + "6345f8926212465291c04587353161f1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, 
+ "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "3aa84c8c097d4858a0c38f18dec8b060": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_a25de7edbbee47cc8094f43125efa39b", + "_dom_classes": [], + "description": "Total: 86.3, Final: 0.0: 100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 400, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 400, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_a0bb5061ca0946ab89a89f9abadcebc8" + } + }, + "6647e68cf064416ca593b990bed81edf": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_d99d8e6bcbe0445eb7eddbfe31277635", + "_dom_classes": [], + "description": 
"", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 400/400 [11:36<00:00, 1.74s/it]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_5af1c6579f0f4c2eb03df4063749569e" + } + }, + "a25de7edbbee47cc8094f43125efa39b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "a0bb5061ca0946ab89a89f9abadcebc8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + 
"margin": null, + "display": null, + "left": null + } + }, + "d99d8e6bcbe0445eb7eddbfe31277635": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "5af1c6579f0f4c2eb03df4063749569e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Fp30SB4bxeQb" + }, + "source": [ + "# **Homework 12 - Reinforcement Learning**\n", + "\n", + "If you have any problem, e-mail us at ntu-ml-2021spring-ta@googlegroups.com\n", + "\n" + ] + }, + { + "cell_type": 
"markdown", + "metadata": { + "id": "yXsnCWPtWSNk" + }, + "source": [ + "## Preliminary work\n", + "\n", + "First, we need to install all necessary packages.\n", + "One of them, gym, builded by OpenAI, is a toolkit for developing Reinforcement Learning algorithm. Other packages are for visualization in colab." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5e2bScpnkVbv", + "outputId": "52198e39-e2a2-4ea2-a1f3-4ba9545476d7" + }, + "source": [ + "!apt update\n", + "!apt install python-opengl xvfb -y\n", + "!pip install gym[box2d]==0.18.3 pyvirtualdisplay tqdm numpy==1.19.5 torch==1.8.1" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Hit:1 http://security.ubuntu.com/ubuntu bionic-security InRelease\n", + "Ign:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 InRelease\n", + "Hit:3 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease\n", + "Ign:4 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 InRelease\n", + "Hit:5 http://archive.ubuntu.com/ubuntu bionic InRelease\n", + "Hit:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Release\n", + "Hit:7 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease\n", + "Hit:8 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 Release\n", + "Hit:9 http://archive.ubuntu.com/ubuntu bionic-updates InRelease\n", + "Hit:10 http://archive.ubuntu.com/ubuntu bionic-backports InRelease\n", + "Hit:11 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease\n", + "Hit:12 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease\n", + "Hit:13 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu bionic InRelease\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "86 packages can be upgraded. 
Run 'apt list --upgradable' to see them.\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "python-opengl is already the newest version (3.1.0+dfsg-1).\n", + "xvfb is already the newest version (2:1.19.6-1ubuntu4.9).\n", + "The following package was automatically installed and is no longer required:\n", + " libnvidia-common-460\n", + "Use 'apt autoremove' to remove it.\n", + "0 upgraded, 0 newly installed, 0 to remove and 86 not upgraded.\n", + "Requirement already satisfied: gym[box2d] in /usr/local/lib/python3.7/dist-packages (0.17.3)\n", + "Requirement already satisfied: pyvirtualdisplay in /usr/local/lib/python3.7/dist-packages (2.1)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n", + "Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.19.5)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.4.1)\n", + "Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.5.0)\n", + "Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.3.0)\n", + "Requirement already satisfied: box2d-py~=2.3.5; extra == \"box2d\" in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (2.3.8)\n", + "Requirement already satisfied: EasyProcess in /usr/local/lib/python3.7/dist-packages (from pyvirtualdisplay) (0.3)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from pyglet<=1.5.0,>=1.4.0->gym[box2d]) (0.16.0)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M_-i3cdoYsks" + }, + "source": [ + "\n", + "Next, set up virtual display,and import all necessaary packages." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "nl2nREINDLiw" + }, + "source": [ + "%%capture\n", + "from pyvirtualdisplay import Display\n", + "virtual_display = Display(visible=0, size=(1400, 900))\n", + "virtual_display.start()\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from IPython import display\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.distributions import Categorical\n", + "from tqdm.notebook import tqdm" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CaEJ8BUCpN9P" + }, + "source": [ + "# Warning ! Do not revise random seed !!!\n", + "# Otherwise your submission on JudgeBoi will not reproduce your result !!!\n", + "Make your HW result reproducible.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fV9i8i2YkRbO" + }, + "source": [ + "seed = 543 # Do not change this\n", + "def fix(env, seed):\n", + " env.seed(seed)\n", + " env.action_space.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " np.random.seed(seed)\n", + " random.seed(seed)\n", + " torch.set_deterministic(True)\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.backends.cudnn.deterministic = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "He0XDx6bzjgC" + }, + "source": [ + "Last, call gym and build a [Lunar Lander](https://gym.openai.com/envs/LunarLander-v2/) environment." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "N_4-xJcbBt09" + }, + "source": [ + "%%capture\n", + "import gym\n", + "import random\n", + "env = gym.make('LunarLander-v2')\n", + "fix(env, seed) # fix the environment Do not revise this !!!" 
+ ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NrkVvTrvWZ5H" + }, + "source": [ + "## What is Lunar Lander?\n", + "\n", + "“LunarLander-v2” simulates the situation in which the craft lands on the surface of the moon.\n", + "\n", + "This task is to enable the craft to land \"safely\" at the pad between the two yellow flags.\n", + "> Landing pad is always at coordinates (0,0).\n", + "> Coordinates are the first two numbers in the state vector.\n", + "\n", + "![](https://gym.openai.com/assets/docs/aeloop-138c89d44114492fd02822303e6b4b07213010bb14ca5856d2d49d6b62d88e53.svg)\n", + "\n", + "\"LunarLander-v2\" actually includes \"Agent\" and \"Environment\". \n", + "\n", + "In this homework, we will utilize the function `step()` to control the action of \"Agent\". \n", + "\n", + "Then `step()` will return the observation/state and reward given by the \"Environment\"." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bIbp82sljvAt" + }, + "source": [ + "### Observation / State\n", + "\n", + "First, we can take a look at what an Observation / State looks like." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rsXZra3N9R5T", + "outputId": "a36868de-bbbc-4de9-815b-0b43fc012c96" + }, + "source": [ + "print(env.observation_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Box(-inf, inf, (8,), float32)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ezdfoThbAQ49" + }, + "source": [ + "\n", + "`Box(-inf, inf, (8,), float32)` means that the observation is an 8-dim float vector\n", + "### Action\n", + "\n", + "The actions that the agent can take look like" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "p1k4dIrBAaKi", + "outputId": "80c453ee-539f-4e40-c5d8-8e9dc8fffaef" + }, + "source": [ + "print(env.action_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Discrete(4)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dejXT6PHBrPn" + }, + "source": [ + "`Discrete(4)` implies that there are four kinds of actions that can be taken by the agent.\n", + "- 0 implies the agent will not take any action\n", + "- 2 implies the agent will fire the main engine (the jet points downward, pushing the craft up)\n", + "- 1, 3 imply the agent will fire the left and right orientation engines, respectively\n", + "\n", + "Next, we will try to make the agent interact with the environment. \n", + "Before taking any actions, we recommend calling the `reset()` function to reset the environment. Also, this function will return the initial state of the environment." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pi4OmrmZgnWA", + "outputId": "2635bdfb-a4dc-442b-a21a-f57af67edc4b" + }, + "source": [ + "initial_state = env.reset()\n", + "print(initial_state)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.00396109 1.4083536 0.40119505 -0.11407257 -0.00458307 -0.09087662\n", + " 0. 0. ]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uBx0mEqqgxJ9" + }, + "source": [ + "Then, we try to get a random action from the agent's action space." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxkOEXRKgizt", + "outputId": "de93c740-f01c-464e-f436-a2b59e7dc7e5" + }, + "source": [ + "random_action = env.action_space.sample()\n", + "print(random_action)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mns-bO01g0-J" + }, + "source": [ + "More, we can utilize `step()` to make agent act according to the randomly-selected `random_action`.\n", + "The `step()` function will return four values:\n", + "- observation / state\n", + "- reward\n", + "- done (True/ False)\n", + "- Other information" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "E_WViSxGgIk9" + }, + "source": [ + "observation, reward, done, info = env.step(random_action)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yK7r126kuCNp", + "outputId": "2f3363d9-5bc3-4ba5-86f2-1c4abc89f179" + }, + "source": [ + "print(done)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "False\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", 
+ "metadata": { + "id": "GKdS8vOihxhc" + }, + "source": [ + "### Reward\n", + "\n", + "\n", + "> Landing pad is always at coordinates (0,0). Coordinates are the first two numbers in state vector. Reward for moving from the top of the screen to landing pad and zero speed is about 100..140 points. If lander moves away from landing pad it loses reward back. Episode finishes if the lander crashes or comes to rest, receiving additional -100 or +100 points. Each leg ground contact is +10. Firing main engine is -0.3 points each frame. Solved is 200 points. " + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxQNs77hi0_7", + "outputId": "4633d678-be4f-4f52-8f91-1b6681642580" + }, + "source": [ + "print(reward)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-0.8588900517154912\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mhqp6D-XgHpe" + }, + "source": [ + "### Random Agent\n", + "In the end, before we start training, we can see whether a random agent can successfully land the moon or not." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 269 + }, + "id": "Y3G0bxoccelv", + "outputId": "11ad28c1-058b-4243-bf35-1fdfbb60be9e" + }, + "source": [ + "env.reset()\n", + "\n", + "img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + "done = False\n", + "while not done:\n", + " action = env.action_space.sample()\n", + " observation, reward, done, _ = env.step(action)\n", + "\n", + " img.set_data(env.render(mode='rgb_array'))\n", + " display.display(plt.gcf())\n", + " display.clear_output(wait=True)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F5paWqo7tWL2" + }, + "source": [ + "## Policy Gradient\n", + "Now, we can build a simple policy network. The network will return one of action in the action space." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "J8tdmeD-tZew" + }, + "source": [ + "class PolicyGradientNetwork(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(8, 16)\n", + " self.fc2 = nn.Linear(16, 16)\n", + " self.fc3 = nn.Linear(16, 4)\n", + "\n", + " def forward(self, state):\n", + " hid = torch.tanh(self.fc1(state))\n", + " hid = torch.tanh(self.fc2(hid))\n", + " return F.softmax(self.fc3(hid), dim=-1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ynbqJrhIFTC3" + }, + "source": [ + "Then, we need to build a simple agent. The agent will acts according to the output of the policy network above. There are a few things can be done by agent:\n", + "- `learn()`:update the policy network from log probabilities and rewards.\n", + "- `sample()`:After receiving observation from the environment, utilize policy network to tell which action to take. The return values of this function includes action and log probabilities. 
" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zZo-IxJx286z" + }, + "source": [ + "from torch.optim.lr_scheduler import StepLR\n", + "class PolicyGradientAgent():\n", + " \n", + " def __init__(self, network):\n", + " self.network = network\n", + " self.optimizer = optim.SGD(self.network.parameters(), lr=0.001)\n", + " \n", + " def forward(self, state):\n", + " return self.network(state)\n", + " def learn(self, log_probs, rewards):\n", + " loss = (-log_probs * rewards).sum() # You don't need to revise this to pass simple baseline (but you can)\n", + "\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " \n", + " def sample(self, state):\n", + " action_prob = self.network(torch.FloatTensor(state))\n", + " action_dist = Categorical(action_prob)\n", + " action = action_dist.sample()\n", + " log_prob = action_dist.log_prob(action)\n", + " return action.item(), log_prob" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehPlnTKyRZf9" + }, + "source": [ + "Lastly, build a network and agent to start training." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GfJIvML-RYjL" + }, + "source": [ + "network = PolicyGradientNetwork()\n", + "agent = PolicyGradientAgent(network)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ouv23glgf5Qt" + }, + "source": [ + "## Training Agent\n", + "\n", + "Now let's start to train our agent.\n", + "By taking all the interactions between the agent and the environment as training data, the policy network can learn from all these attempts." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "de3a153737af485ea436d7e8393d8248", + "6345f8926212465291c04587353161f1", + "3aa84c8c097d4858a0c38f18dec8b060", + "6647e68cf064416ca593b990bed81edf", + "a25de7edbbee47cc8094f43125efa39b", + "a0bb5061ca0946ab89a89f9abadcebc8", + "d99d8e6bcbe0445eb7eddbfe31277635", + "5af1c6579f0f4c2eb03df4063749569e" + ] + }, + "id": "vg5rxBBaf38_", + "outputId": "a1b06e39-99d6-4233-eda0-a3d58e77ffed" + }, + "source": [ + "agent.network.train() # Switch network into training mode \n", + "EPISODE_PER_BATCH = 5 # update the agent every 5 episode\n", + "NUM_BATCH = 400 # totally update the agent for 400 time\n", + "\n", + "avg_total_rewards, avg_final_rewards = [], []\n", + "\n", + "prg_bar = tqdm(range(NUM_BATCH))\n", + "for batch in prg_bar:\n", + "\n", + " log_probs, rewards = [], []\n", + " total_rewards, final_rewards = [], []\n", + "\n", + " # collect trajectory\n", + " for episode in range(EPISODE_PER_BATCH):\n", + " \n", + " state = env.reset()\n", + " total_reward, total_step = 0, 0\n", + " seq_rewards = []\n", + " while True:\n", + "\n", + " action, log_prob = agent.sample(state) # at, log(at|st)\n", + " next_state, reward, done, _ = env.step(action)\n", + "\n", + " log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]\n", + " # seq_rewards.append(reward)\n", + " state = 
next_state\n", + " total_reward += reward\n", + " total_step += 1\n", + " rewards.append(reward) # change here\n", + " # ! IMPORTANT !\n", + " # Current reward implementation: immediate reward, given action_list : a1, a2, a3 ......\n", + " # rewards : r1, r2 ,r3 ......\n", + " # medium:change \"rewards\" to accumulative decaying reward, given action_list : a1, a2, a3, ......\n", + " # rewards : r1+0.99*r2+0.99^2*r3+......, r2+0.99*r3+0.99^2*r4+...... , r3+0.99*r4+0.99^2*r5+ ......\n", + " # boss : implement DQN\n", + " if done:\n", + " final_rewards.append(reward)\n", + " total_rewards.append(total_reward)\n", + " \n", + " break\n", + "\n", + " print(f\"rewards looks like \", np.shape(rewards)) \n", + " print(f\"log_probs looks like \", np.shape(log_probs)) \n", + " # record training process\n", + " avg_total_reward = sum(total_rewards) / len(total_rewards)\n", + " avg_final_reward = sum(final_rewards) / len(final_rewards)\n", + " avg_total_rewards.append(avg_total_reward)\n", + " avg_final_rewards.append(avg_final_reward)\n", + " prg_bar.set_description(f\"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}\")\n", + "\n", + " # update agent\n", + " # rewards = np.concatenate(rewards, axis=0)\n", + " rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9) # normalize the reward \n", + " agent.learn(torch.stack(log_probs), torch.from_numpy(rewards))\n", + " print(\"logs prob looks like \", torch.stack(log_probs).size())\n", + " print(\"torch.from_numpy(rewards) looks like \", torch.from_numpy(rewards).size())" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "de3a153737af485ea436d7e8393d8248", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "rewards looks 
like (448,)\n", + "log_probs looks like (448,)\n", + "logs prob looks like torch.Size([448])\n", + "torch.from_numpy(rewards) looks like torch.Size([448])\n", + "rewards looks like (515,)\n", + "log_probs looks like (515,)\n", + "logs prob looks like torch.Size([515])\n", + "torch.from_numpy(rewards) looks like torch.Size([515])\n", + "rewards looks like (392,)\n", + "log_probs looks like (392,)\n", + "logs prob looks like torch.Size([392])\n", + "torch.from_numpy(rewards) looks like torch.Size([392])\n", + "rewards looks like (518,)\n", + "log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (472,)\n", + "log_probs looks like (472,)\n", + "logs prob looks like torch.Size([472])\n", + "torch.from_numpy(rewards) looks like torch.Size([472])\n", + "rewards looks like (530,)\n", + "log_probs looks like (530,)\n", + "logs prob looks like torch.Size([530])\n", + "torch.from_numpy(rewards) looks like torch.Size([530])\n", + "rewards looks like (463,)\n", + "log_probs looks like (463,)\n", + "logs prob looks like torch.Size([463])\n", + "torch.from_numpy(rewards) looks like torch.Size([463])\n", + "rewards looks like (540,)\n", + "log_probs looks like (540,)\n", + "logs prob looks like torch.Size([540])\n", + "torch.from_numpy(rewards) looks like torch.Size([540])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (449,)\n", + "log_probs looks like (449,)\n", + "logs prob looks like torch.Size([449])\n", + "torch.from_numpy(rewards) looks like torch.Size([449])\n", + "rewards looks like (602,)\n", + "log_probs looks like (602,)\n", + "logs prob looks like torch.Size([602])\n", + "torch.from_numpy(rewards) looks like torch.Size([602])\n", + "rewards looks like (542,)\n", + "log_probs looks like (542,)\n", + "logs prob 
looks like torch.Size([542])\n", + "torch.from_numpy(rewards) looks like torch.Size([542])\n", + "rewards looks like (503,)\n", + "log_probs looks like (503,)\n", + "logs prob looks like torch.Size([503])\n", + "torch.from_numpy(rewards) looks like torch.Size([503])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (518,)\n", + "log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (421,)\n", + "log_probs looks like (421,)\n", + "logs prob looks like torch.Size([421])\n", + "torch.from_numpy(rewards) looks like torch.Size([421])\n", + "rewards looks like (592,)\n", + "log_probs looks like (592,)\n", + "logs prob looks like torch.Size([592])\n", + "torch.from_numpy(rewards) looks like torch.Size([592])\n", + "rewards looks like (520,)\n", + "log_probs looks like (520,)\n", + "logs prob looks like torch.Size([520])\n", + "torch.from_numpy(rewards) looks like torch.Size([520])\n", + "rewards looks like (494,)\n", + "log_probs looks like (494,)\n", + "logs prob looks like torch.Size([494])\n", + "torch.from_numpy(rewards) looks like torch.Size([494])\n", + "rewards looks like (461,)\n", + "log_probs looks like (461,)\n", + "logs prob looks like torch.Size([461])\n", + "torch.from_numpy(rewards) looks like torch.Size([461])\n", + "rewards looks like (572,)\n", + "log_probs looks like (572,)\n", + "logs prob looks like torch.Size([572])\n", + "torch.from_numpy(rewards) looks like torch.Size([572])\n", + "rewards looks like (593,)\n", + "log_probs looks like (593,)\n", + "logs prob looks like torch.Size([593])\n", + "torch.from_numpy(rewards) looks like torch.Size([593])\n", + "rewards looks like (569,)\n", + "log_probs looks like (569,)\n", + "logs prob looks like torch.Size([569])\n", + "torch.from_numpy(rewards) 
looks like torch.Size([569])\n", + "rewards looks like (546,)\n", + "log_probs looks like (546,)\n", + "logs prob looks like torch.Size([546])\n", + "torch.from_numpy(rewards) looks like torch.Size([546])\n", + "rewards looks like (612,)\n", + "log_probs looks like (612,)\n", + "logs prob looks like torch.Size([612])\n", + "torch.from_numpy(rewards) looks like torch.Size([612])\n", + "rewards looks like (534,)\n", + "log_probs looks like (534,)\n", + "logs prob looks like torch.Size([534])\n", + "torch.from_numpy(rewards) looks like torch.Size([534])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (535,)\n", + "log_probs looks like (535,)\n", + "logs prob looks like torch.Size([535])\n", + "torch.from_numpy(rewards) looks like torch.Size([535])\n", + "rewards looks like (533,)\n", + "log_probs looks like (533,)\n", + "logs prob looks like torch.Size([533])\n", + "torch.from_numpy(rewards) looks like torch.Size([533])\n", + "rewards looks like (521,)\n", + "log_probs looks like (521,)\n", + "logs prob looks like torch.Size([521])\n", + "torch.from_numpy(rewards) looks like torch.Size([521])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (575,)\n", + "log_probs looks like (575,)\n", + "logs prob looks like torch.Size([575])\n", + "torch.from_numpy(rewards) looks like torch.Size([575])\n", + "rewards looks like 
(709,)\n", + "log_probs looks like (709,)\n", + "logs prob looks like torch.Size([709])\n", + "torch.from_numpy(rewards) looks like torch.Size([709])\n", + "rewards looks like (486,)\n", + "log_probs looks like (486,)\n", + "logs prob looks like torch.Size([486])\n", + "torch.from_numpy(rewards) looks like torch.Size([486])\n", + "rewards looks like (557,)\n", + "log_probs looks like (557,)\n", + "logs prob looks like torch.Size([557])\n", + "torch.from_numpy(rewards) looks like torch.Size([557])\n", + "rewards looks like (517,)\n", + "log_probs looks like (517,)\n", + "logs prob looks like torch.Size([517])\n", + "torch.from_numpy(rewards) looks like torch.Size([517])\n", + "rewards looks like (550,)\n", + "log_probs looks like (550,)\n", + "logs prob looks like torch.Size([550])\n", + "torch.from_numpy(rewards) looks like torch.Size([550])\n", + "rewards looks like (690,)\n", + "log_probs looks like (690,)\n", + "logs prob looks like torch.Size([690])\n", + "torch.from_numpy(rewards) looks like torch.Size([690])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (689,)\n", + "log_probs looks like (689,)\n", + "logs prob looks like torch.Size([689])\n", + "torch.from_numpy(rewards) looks like torch.Size([689])\n", + "rewards looks like (1059,)\n", + "log_probs looks like (1059,)\n", + "logs prob looks like torch.Size([1059])\n", + "torch.from_numpy(rewards) looks like torch.Size([1059])\n", + "rewards looks like (619,)\n", + "log_probs looks like (619,)\n", + "logs prob looks like torch.Size([619])\n", + "torch.from_numpy(rewards) looks like torch.Size([619])\n", + "rewards looks like (527,)\n", + "log_probs looks like (527,)\n", + "logs prob looks like torch.Size([527])\n", + "torch.from_numpy(rewards) looks like torch.Size([527])\n", + "rewards looks like (514,)\n", + "log_probs looks like (514,)\n", + "logs prob 
looks like torch.Size([514])\n", + "torch.from_numpy(rewards) looks like torch.Size([514])\n", + "rewards looks like (655,)\n", + "log_probs looks like (655,)\n", + "logs prob looks like torch.Size([655])\n", + "torch.from_numpy(rewards) looks like torch.Size([655])\n", + "rewards looks like (667,)\n", + "log_probs looks like (667,)\n", + "logs prob looks like torch.Size([667])\n", + "torch.from_numpy(rewards) looks like torch.Size([667])\n", + "rewards looks like (712,)\n", + "log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (636,)\n", + "log_probs looks like (636,)\n", + "logs prob looks like torch.Size([636])\n", + "torch.from_numpy(rewards) looks like torch.Size([636])\n", + "rewards looks like (620,)\n", + "log_probs looks like (620,)\n", + "logs prob looks like torch.Size([620])\n", + "torch.from_numpy(rewards) looks like torch.Size([620])\n", + "rewards looks like (543,)\n", + "log_probs looks like (543,)\n", + "logs prob looks like torch.Size([543])\n", + "torch.from_numpy(rewards) looks like torch.Size([543])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (498,)\n", + "log_probs looks like (498,)\n", + "logs prob looks like torch.Size([498])\n", + "torch.from_numpy(rewards) looks like torch.Size([498])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (693,)\n", + "log_probs looks like (693,)\n", + "logs prob looks like torch.Size([693])\n", + "torch.from_numpy(rewards) 
looks like torch.Size([693])\n", + "rewards looks like (648,)\n", + "log_probs looks like (648,)\n", + "logs prob looks like torch.Size([648])\n", + "torch.from_numpy(rewards) looks like torch.Size([648])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (574,)\n", + "log_probs looks like (574,)\n", + "logs prob looks like torch.Size([574])\n", + "torch.from_numpy(rewards) looks like torch.Size([574])\n", + "rewards looks like (718,)\n", + "log_probs looks like (718,)\n", + "logs prob looks like torch.Size([718])\n", + "torch.from_numpy(rewards) looks like torch.Size([718])\n", + "rewards looks like (730,)\n", + "log_probs looks like (730,)\n", + "logs prob looks like torch.Size([730])\n", + "torch.from_numpy(rewards) looks like torch.Size([730])\n", + "rewards looks like (668,)\n", + "log_probs looks like (668,)\n", + "logs prob looks like torch.Size([668])\n", + "torch.from_numpy(rewards) looks like torch.Size([668])\n", + "rewards looks like (754,)\n", + "log_probs looks like (754,)\n", + "logs prob looks like torch.Size([754])\n", + "torch.from_numpy(rewards) looks like torch.Size([754])\n", + "rewards looks like (712,)\n", + "log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (585,)\n", + "log_probs looks like (585,)\n", + "logs prob looks like torch.Size([585])\n", + "torch.from_numpy(rewards) looks like torch.Size([585])\n", + "rewards looks like 
(512,)\n", + "log_probs looks like (512,)\n", + "logs prob looks like torch.Size([512])\n", + "torch.from_numpy(rewards) looks like torch.Size([512])\n", + "rewards looks like (702,)\n", + "log_probs looks like (702,)\n", + "logs prob looks like torch.Size([702])\n", + "torch.from_numpy(rewards) looks like torch.Size([702])\n", + "rewards looks like (596,)\n", + "log_probs looks like (596,)\n", + "logs prob looks like torch.Size([596])\n", + "torch.from_numpy(rewards) looks like torch.Size([596])\n", + "rewards looks like (626,)\n", + "log_probs looks like (626,)\n", + "logs prob looks like torch.Size([626])\n", + "torch.from_numpy(rewards) looks like torch.Size([626])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (717,)\n", + "log_probs looks like (717,)\n", + "logs prob looks like torch.Size([717])\n", + "torch.from_numpy(rewards) looks like torch.Size([717])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (565,)\n", + "log_probs looks like (565,)\n", + "logs prob looks like torch.Size([565])\n", + "torch.from_numpy(rewards) looks like torch.Size([565])\n", + "rewards looks like (450,)\n", + "log_probs looks like (450,)\n", + "logs prob looks like torch.Size([450])\n", + "torch.from_numpy(rewards) looks like torch.Size([450])\n", + "rewards looks like (584,)\n", + "log_probs looks like (584,)\n", + "logs prob looks like torch.Size([584])\n", + "torch.from_numpy(rewards) looks like torch.Size([584])\n", + "rewards looks like (670,)\n", + "log_probs looks like (670,)\n", + "logs prob looks like torch.Size([670])\n", + "torch.from_numpy(rewards) looks like torch.Size([670])\n", + "rewards looks like (691,)\n", + "log_probs looks like (691,)\n", + "logs prob 
looks like torch.Size([691])\n", + "torch.from_numpy(rewards) looks like torch.Size([691])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (752,)\n", + "log_probs looks like (752,)\n", + "logs prob looks like torch.Size([752])\n", + "torch.from_numpy(rewards) looks like torch.Size([752])\n", + "rewards looks like (478,)\n", + "log_probs looks like (478,)\n", + "logs prob looks like torch.Size([478])\n", + "torch.from_numpy(rewards) looks like torch.Size([478])\n", + "rewards looks like (553,)\n", + "log_probs looks like (553,)\n", + "logs prob looks like torch.Size([553])\n", + "torch.from_numpy(rewards) looks like torch.Size([553])\n", + "rewards looks like (1660,)\n", + "log_probs looks like (1660,)\n", + "logs prob looks like torch.Size([1660])\n", + "torch.from_numpy(rewards) looks like torch.Size([1660])\n", + "rewards looks like (751,)\n", + "log_probs looks like (751,)\n", + "logs prob looks like torch.Size([751])\n", + "torch.from_numpy(rewards) looks like torch.Size([751])\n", + "rewards looks like (801,)\n", + "log_probs looks like (801,)\n", + "logs prob looks like torch.Size([801])\n", + "torch.from_numpy(rewards) looks like torch.Size([801])\n", + "rewards looks like (715,)\n", + "log_probs looks like (715,)\n", + "logs prob looks like torch.Size([715])\n", + "torch.from_numpy(rewards) looks like torch.Size([715])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (609,)\n", + "log_probs looks like (609,)\n", + "logs prob looks like torch.Size([609])\n", + "torch.from_numpy(rewards) looks like torch.Size([609])\n", + "rewards looks like (732,)\n", + "log_probs looks like (732,)\n", + "logs prob looks like torch.Size([732])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([732])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (658,)\n", + "log_probs looks like (658,)\n", + "logs prob looks like torch.Size([658])\n", + "torch.from_numpy(rewards) looks like torch.Size([658])\n", + "rewards looks like (783,)\n", + "log_probs looks like (783,)\n", + "logs prob looks like torch.Size([783])\n", + "torch.from_numpy(rewards) looks like torch.Size([783])\n", + "rewards looks like (652,)\n", + "log_probs looks like (652,)\n", + "logs prob looks like torch.Size([652])\n", + "torch.from_numpy(rewards) looks like torch.Size([652])\n", + "rewards looks like (892,)\n", + "log_probs looks like (892,)\n", + "logs prob looks like torch.Size([892])\n", + "torch.from_numpy(rewards) looks like torch.Size([892])\n", + "rewards looks like (821,)\n", + "log_probs looks like (821,)\n", + "logs prob looks like torch.Size([821])\n", + "torch.from_numpy(rewards) looks like torch.Size([821])\n", + "rewards looks like (986,)\n", + "log_probs looks like (986,)\n", + "logs prob looks like torch.Size([986])\n", + "torch.from_numpy(rewards) looks like torch.Size([986])\n", + "rewards looks like (916,)\n", + "log_probs looks like (916,)\n", + "logs prob looks like torch.Size([916])\n", + "torch.from_numpy(rewards) looks like torch.Size([916])\n", + "rewards looks like (742,)\n", + "log_probs looks like (742,)\n", + "logs prob looks like torch.Size([742])\n", + "torch.from_numpy(rewards) looks like torch.Size([742])\n", + 
"rewards looks like (604,)\n", + "log_probs looks like (604,)\n", + "logs prob looks like torch.Size([604])\n", + "torch.from_numpy(rewards) looks like torch.Size([604])\n", + "rewards looks like (818,)\n", + "log_probs looks like (818,)\n", + "logs prob looks like torch.Size([818])\n", + "torch.from_numpy(rewards) looks like torch.Size([818])\n", + "rewards looks like (855,)\n", + "log_probs looks like (855,)\n", + "logs prob looks like torch.Size([855])\n", + "torch.from_numpy(rewards) looks like torch.Size([855])\n", + "rewards looks like (795,)\n", + "log_probs looks like (795,)\n", + "logs prob looks like torch.Size([795])\n", + "torch.from_numpy(rewards) looks like torch.Size([795])\n", + "rewards looks like (868,)\n", + "log_probs looks like (868,)\n", + "logs prob looks like torch.Size([868])\n", + "torch.from_numpy(rewards) looks like torch.Size([868])\n", + "rewards looks like (800,)\n", + "log_probs looks like (800,)\n", + "logs prob looks like torch.Size([800])\n", + "torch.from_numpy(rewards) looks like torch.Size([800])\n", + "rewards looks like (820,)\n", + "log_probs looks like (820,)\n", + "logs prob looks like torch.Size([820])\n", + "torch.from_numpy(rewards) looks like torch.Size([820])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (886,)\n", + "log_probs looks like (886,)\n", + "logs prob looks like torch.Size([886])\n", + "torch.from_numpy(rewards) looks like torch.Size([886])\n", + "rewards looks like (1027,)\n", + "log_probs looks like (1027,)\n", + "logs prob looks like torch.Size([1027])\n", + "torch.from_numpy(rewards) looks like torch.Size([1027])\n", + "rewards looks like (819,)\n", + "log_probs looks like (819,)\n", + "logs prob looks like torch.Size([819])\n", + "torch.from_numpy(rewards) looks like torch.Size([819])\n", + "rewards looks like (934,)\n", + "log_probs looks like 
(934,)\n", + "logs prob looks like torch.Size([934])\n", + "torch.from_numpy(rewards) looks like torch.Size([934])\n", + "rewards looks like (1648,)\n", + "log_probs looks like (1648,)\n", + "logs prob looks like torch.Size([1648])\n", + "torch.from_numpy(rewards) looks like torch.Size([1648])\n", + "rewards looks like (1057,)\n", + "log_probs looks like (1057,)\n", + "logs prob looks like torch.Size([1057])\n", + "torch.from_numpy(rewards) looks like torch.Size([1057])\n", + "rewards looks like (861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1533,)\n", + "log_probs looks like (1533,)\n", + "logs prob looks like torch.Size([1533])\n", + "torch.from_numpy(rewards) looks like torch.Size([1533])\n", + "rewards looks like (920,)\n", + "log_probs looks like (920,)\n", + "logs prob looks like torch.Size([920])\n", + "torch.from_numpy(rewards) looks like torch.Size([920])\n", + "rewards looks like (905,)\n", + "log_probs looks like (905,)\n", + "logs prob looks like torch.Size([905])\n", + "torch.from_numpy(rewards) looks like torch.Size([905])\n", + "rewards looks like (814,)\n", + "log_probs looks like (814,)\n", + "logs prob looks like torch.Size([814])\n", + "torch.from_numpy(rewards) looks like torch.Size([814])\n", + "rewards looks like (809,)\n", + "log_probs looks like (809,)\n", + "logs prob looks like torch.Size([809])\n", + "torch.from_numpy(rewards) looks like torch.Size([809])\n", + "rewards looks like (873,)\n", + "log_probs looks like (873,)\n", + "logs prob looks like torch.Size([873])\n", + "torch.from_numpy(rewards) looks like torch.Size([873])\n", + "rewards looks like (727,)\n", + "log_probs looks like (727,)\n", + "logs prob looks like torch.Size([727])\n", + "torch.from_numpy(rewards) looks like torch.Size([727])\n", + "rewards looks like (1129,)\n", + "log_probs looks like (1129,)\n", + "logs prob looks like 
torch.Size([1129])\n", + "torch.from_numpy(rewards) looks like torch.Size([1129])\n", + "rewards looks like (1394,)\n", + "log_probs looks like (1394,)\n", + "logs prob looks like torch.Size([1394])\n", + "torch.from_numpy(rewards) looks like torch.Size([1394])\n", + "rewards looks like (884,)\n", + "log_probs looks like (884,)\n", + "logs prob looks like torch.Size([884])\n", + "torch.from_numpy(rewards) looks like torch.Size([884])\n", + "rewards looks like (1132,)\n", + "log_probs looks like (1132,)\n", + "logs prob looks like torch.Size([1132])\n", + "torch.from_numpy(rewards) looks like torch.Size([1132])\n", + "rewards looks like (1007,)\n", + "log_probs looks like (1007,)\n", + "logs prob looks like torch.Size([1007])\n", + "torch.from_numpy(rewards) looks like torch.Size([1007])\n", + "rewards looks like (711,)\n", + "log_probs looks like (711,)\n", + "logs prob looks like torch.Size([711])\n", + "torch.from_numpy(rewards) looks like torch.Size([711])\n", + "rewards looks like (836,)\n", + "log_probs looks like (836,)\n", + "logs prob looks like torch.Size([836])\n", + "torch.from_numpy(rewards) looks like torch.Size([836])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (896,)\n", + "log_probs looks like (896,)\n", + "logs prob looks like torch.Size([896])\n", + "torch.from_numpy(rewards) looks like torch.Size([896])\n", + "rewards looks like (912,)\n", + "log_probs looks like (912,)\n", + "logs prob looks like torch.Size([912])\n", + "torch.from_numpy(rewards) looks like torch.Size([912])\n", + "rewards looks like (1478,)\n", + "log_probs looks like (1478,)\n", + "logs prob looks like torch.Size([1478])\n", + "torch.from_numpy(rewards) looks like torch.Size([1478])\n", + "rewards looks like (1279,)\n", + "log_probs looks like (1279,)\n", + "logs prob looks like torch.Size([1279])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1279])\n", + "rewards looks like (676,)\n", + "log_probs looks like (676,)\n", + "logs prob looks like torch.Size([676])\n", + "torch.from_numpy(rewards) looks like torch.Size([676])\n", + "rewards looks like (1768,)\n", + "log_probs looks like (1768,)\n", + "logs prob looks like torch.Size([1768])\n", + "torch.from_numpy(rewards) looks like torch.Size([1768])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (1119,)\n", + "log_probs looks like (1119,)\n", + "logs prob looks like torch.Size([1119])\n", + "torch.from_numpy(rewards) looks like torch.Size([1119])\n", + "rewards looks like (943,)\n", + "log_probs looks like (943,)\n", + "logs prob looks like torch.Size([943])\n", + "torch.from_numpy(rewards) looks like torch.Size([943])\n", + "rewards looks like (1255,)\n", + "log_probs looks like (1255,)\n", + "logs prob looks like torch.Size([1255])\n", + "torch.from_numpy(rewards) looks like torch.Size([1255])\n", + "rewards looks like (861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1149,)\n", + "log_probs looks like (1149,)\n", + "logs prob looks like torch.Size([1149])\n", + "torch.from_numpy(rewards) looks like torch.Size([1149])\n", + "rewards looks like (1229,)\n", + "log_probs looks like (1229,)\n", + "logs prob looks like torch.Size([1229])\n", + "torch.from_numpy(rewards) looks like torch.Size([1229])\n", + "rewards looks like (1680,)\n", + "log_probs looks like (1680,)\n", + "logs prob looks like torch.Size([1680])\n", + "torch.from_numpy(rewards) looks like torch.Size([1680])\n", + "rewards looks like (1731,)\n", + "log_probs looks like (1731,)\n", + "logs prob looks like torch.Size([1731])\n", + "torch.from_numpy(rewards) looks 
like torch.Size([1731])\n", + "rewards looks like (1017,)\n", + "log_probs looks like (1017,)\n", + "logs prob looks like torch.Size([1017])\n", + "torch.from_numpy(rewards) looks like torch.Size([1017])\n", + "rewards looks like (990,)\n", + "log_probs looks like (990,)\n", + "logs prob looks like torch.Size([990])\n", + "torch.from_numpy(rewards) looks like torch.Size([990])\n", + "rewards looks like (1020,)\n", + "log_probs looks like (1020,)\n", + "logs prob looks like torch.Size([1020])\n", + "torch.from_numpy(rewards) looks like torch.Size([1020])\n", + "rewards looks like (1240,)\n", + "log_probs looks like (1240,)\n", + "logs prob looks like torch.Size([1240])\n", + "torch.from_numpy(rewards) looks like torch.Size([1240])\n", + "rewards looks like (774,)\n", + "log_probs looks like (774,)\n", + "logs prob looks like torch.Size([774])\n", + "torch.from_numpy(rewards) looks like torch.Size([774])\n", + "rewards looks like (1069,)\n", + "log_probs looks like (1069,)\n", + "logs prob looks like torch.Size([1069])\n", + "torch.from_numpy(rewards) looks like torch.Size([1069])\n", + "rewards looks like (1355,)\n", + "log_probs looks like (1355,)\n", + "logs prob looks like torch.Size([1355])\n", + "torch.from_numpy(rewards) looks like torch.Size([1355])\n", + "rewards looks like (1556,)\n", + "log_probs looks like (1556,)\n", + "logs prob looks like torch.Size([1556])\n", + "torch.from_numpy(rewards) looks like torch.Size([1556])\n", + "rewards looks like (1840,)\n", + "log_probs looks like (1840,)\n", + "logs prob looks like torch.Size([1840])\n", + "torch.from_numpy(rewards) looks like torch.Size([1840])\n", + "rewards looks like (1352,)\n", + "log_probs looks like (1352,)\n", + "logs prob looks like torch.Size([1352])\n", + "torch.from_numpy(rewards) looks like torch.Size([1352])\n", + "rewards looks like (1617,)\n", + "log_probs looks like (1617,)\n", + "logs prob looks like torch.Size([1617])\n", + "torch.from_numpy(rewards) looks like torch.Size([1617])\n", 
+ "rewards looks like (1637,)\n", + "log_probs looks like (1637,)\n", + "logs prob looks like torch.Size([1637])\n", + "torch.from_numpy(rewards) looks like torch.Size([1637])\n", + "rewards looks like (1606,)\n", + "log_probs looks like (1606,)\n", + "logs prob looks like torch.Size([1606])\n", + "torch.from_numpy(rewards) looks like torch.Size([1606])\n", + "rewards looks like (860,)\n", + "log_probs looks like (860,)\n", + "logs prob looks like torch.Size([860])\n", + "torch.from_numpy(rewards) looks like torch.Size([860])\n", + "rewards looks like (1780,)\n", + "log_probs looks like (1780,)\n", + "logs prob looks like torch.Size([1780])\n", + "torch.from_numpy(rewards) looks like torch.Size([1780])\n", + "rewards looks like (2248,)\n", + "log_probs looks like (2248,)\n", + "logs prob looks like torch.Size([2248])\n", + "torch.from_numpy(rewards) looks like torch.Size([2248])\n", + "rewards looks like (1410,)\n", + "log_probs looks like (1410,)\n", + "logs prob looks like torch.Size([1410])\n", + "torch.from_numpy(rewards) looks like torch.Size([1410])\n", + "rewards looks like (557,)\n", + "log_probs looks like (557,)\n", + "logs prob looks like torch.Size([557])\n", + "torch.from_numpy(rewards) looks like torch.Size([557])\n", + "rewards looks like (719,)\n", + "log_probs looks like (719,)\n", + "logs prob looks like torch.Size([719])\n", + "torch.from_numpy(rewards) looks like torch.Size([719])\n", + "rewards looks like (1919,)\n", + "log_probs looks like (1919,)\n", + "logs prob looks like torch.Size([1919])\n", + "torch.from_numpy(rewards) looks like torch.Size([1919])\n", + "rewards looks like (1250,)\n", + "log_probs looks like (1250,)\n", + "logs prob looks like torch.Size([1250])\n", + "torch.from_numpy(rewards) looks like torch.Size([1250])\n", + "rewards looks like (1054,)\n", + "log_probs looks like (1054,)\n", + "logs prob looks like torch.Size([1054])\n", + "torch.from_numpy(rewards) looks like torch.Size([1054])\n", + "rewards looks like 
(1276,)\n", + "log_probs looks like (1276,)\n", + "logs prob looks like torch.Size([1276])\n", + "torch.from_numpy(rewards) looks like torch.Size([1276])\n", + "rewards looks like (1040,)\n", + "log_probs looks like (1040,)\n", + "logs prob looks like torch.Size([1040])\n", + "torch.from_numpy(rewards) looks like torch.Size([1040])\n", + "rewards looks like (991,)\n", + "log_probs looks like (991,)\n", + "logs prob looks like torch.Size([991])\n", + "torch.from_numpy(rewards) looks like torch.Size([991])\n", + "rewards looks like (1390,)\n", + "log_probs looks like (1390,)\n", + "logs prob looks like torch.Size([1390])\n", + "torch.from_numpy(rewards) looks like torch.Size([1390])\n", + "rewards looks like (1349,)\n", + "log_probs looks like (1349,)\n", + "logs prob looks like torch.Size([1349])\n", + "torch.from_numpy(rewards) looks like torch.Size([1349])\n", + "rewards looks like (1332,)\n", + "log_probs looks like (1332,)\n", + "logs prob looks like torch.Size([1332])\n", + "torch.from_numpy(rewards) looks like torch.Size([1332])\n", + "rewards looks like (1378,)\n", + "log_probs looks like (1378,)\n", + "logs prob looks like torch.Size([1378])\n", + "torch.from_numpy(rewards) looks like torch.Size([1378])\n", + "rewards looks like (1967,)\n", + "log_probs looks like (1967,)\n", + "logs prob looks like torch.Size([1967])\n", + "torch.from_numpy(rewards) looks like torch.Size([1967])\n", + "rewards looks like (1789,)\n", + "log_probs looks like (1789,)\n", + "logs prob looks like torch.Size([1789])\n", + "torch.from_numpy(rewards) looks like torch.Size([1789])\n", + "rewards looks like (1325,)\n", + "log_probs looks like (1325,)\n", + "logs prob looks like torch.Size([1325])\n", + "torch.from_numpy(rewards) looks like torch.Size([1325])\n", + "rewards looks like (1685,)\n", + "log_probs looks like (1685,)\n", + "logs prob looks like torch.Size([1685])\n", + "torch.from_numpy(rewards) looks like torch.Size([1685])\n", + "rewards looks like (1895,)\n", + 
"log_probs looks like (1895,)\n", + "logs prob looks like torch.Size([1895])\n", + "torch.from_numpy(rewards) looks like torch.Size([1895])\n", + "rewards looks like (1920,)\n", + "log_probs looks like (1920,)\n", + "logs prob looks like torch.Size([1920])\n", + "torch.from_numpy(rewards) looks like torch.Size([1920])\n", + "rewards looks like (1522,)\n", + "log_probs looks like (1522,)\n", + "logs prob looks like torch.Size([1522])\n", + "torch.from_numpy(rewards) looks like torch.Size([1522])\n", + "rewards looks like (1173,)\n", + "log_probs looks like (1173,)\n", + "logs prob looks like torch.Size([1173])\n", + "torch.from_numpy(rewards) looks like torch.Size([1173])\n", + "rewards looks like (2136,)\n", + "log_probs looks like (2136,)\n", + "logs prob looks like torch.Size([2136])\n", + "torch.from_numpy(rewards) looks like torch.Size([2136])\n", + "rewards looks like (1696,)\n", + "log_probs looks like (1696,)\n", + "logs prob looks like torch.Size([1696])\n", + "torch.from_numpy(rewards) looks like torch.Size([1696])\n", + "rewards looks like (568,)\n", + "log_probs looks like (568,)\n", + "logs prob looks like torch.Size([568])\n", + "torch.from_numpy(rewards) looks like torch.Size([568])\n", + "rewards looks like (1475,)\n", + "log_probs looks like (1475,)\n", + "logs prob looks like torch.Size([1475])\n", + "torch.from_numpy(rewards) looks like torch.Size([1475])\n", + "rewards looks like (2470,)\n", + "log_probs looks like (2470,)\n", + "logs prob looks like torch.Size([2470])\n", + "torch.from_numpy(rewards) looks like torch.Size([2470])\n", + "rewards looks like (3053,)\n", + "log_probs looks like (3053,)\n", + "logs prob looks like torch.Size([3053])\n", + "torch.from_numpy(rewards) looks like torch.Size([3053])\n", + "rewards looks like (915,)\n", + "log_probs looks like (915,)\n", + "logs prob looks like torch.Size([915])\n", + "torch.from_numpy(rewards) looks like torch.Size([915])\n", + "rewards looks like (2049,)\n", + "log_probs looks like 
(2049,)\n", + "logs prob looks like torch.Size([2049])\n", + "torch.from_numpy(rewards) looks like torch.Size([2049])\n", + "rewards looks like (2068,)\n", + "log_probs looks like (2068,)\n", + "logs prob looks like torch.Size([2068])\n", + "torch.from_numpy(rewards) looks like torch.Size([2068])\n", + "rewards looks like (2528,)\n", + "log_probs looks like (2528,)\n", + "logs prob looks like torch.Size([2528])\n", + "torch.from_numpy(rewards) looks like torch.Size([2528])\n", + "rewards looks like (1839,)\n", + "log_probs looks like (1839,)\n", + "logs prob looks like torch.Size([1839])\n", + "torch.from_numpy(rewards) looks like torch.Size([1839])\n", + "rewards looks like (497,)\n", + "log_probs looks like (497,)\n", + "logs prob looks like torch.Size([497])\n", + "torch.from_numpy(rewards) looks like torch.Size([497])\n", + "rewards looks like (627,)\n", + "log_probs looks like (627,)\n", + "logs prob looks like torch.Size([627])\n", + "torch.from_numpy(rewards) looks like torch.Size([627])\n", + "rewards looks like (2354,)\n", + "log_probs looks like (2354,)\n", + "logs prob looks like torch.Size([2354])\n", + "torch.from_numpy(rewards) looks like torch.Size([2354])\n", + "rewards looks like (2394,)\n", + "log_probs looks like (2394,)\n", + "logs prob looks like torch.Size([2394])\n", + "torch.from_numpy(rewards) looks like torch.Size([2394])\n", + "rewards looks like (743,)\n", + "log_probs looks like (743,)\n", + "logs prob looks like torch.Size([743])\n", + "torch.from_numpy(rewards) looks like torch.Size([743])\n", + "rewards looks like (1572,)\n", + "log_probs looks like (1572,)\n", + "logs prob looks like torch.Size([1572])\n", + "torch.from_numpy(rewards) looks like torch.Size([1572])\n", + "rewards looks like (2575,)\n", + "log_probs looks like (2575,)\n", + "logs prob looks like torch.Size([2575])\n", + "torch.from_numpy(rewards) looks like torch.Size([2575])\n", + "rewards looks like (2226,)\n", + "log_probs looks like (2226,)\n", + "logs prob looks 
like torch.Size([2226])\n", + "torch.from_numpy(rewards) looks like torch.Size([2226])\n", + "rewards looks like (541,)\n", + "log_probs looks like (541,)\n", + "logs prob looks like torch.Size([541])\n", + "torch.from_numpy(rewards) looks like torch.Size([541])\n", + "rewards looks like (820,)\n", + "log_probs looks like (820,)\n", + "logs prob looks like torch.Size([820])\n", + "torch.from_numpy(rewards) looks like torch.Size([820])\n", + "rewards looks like (2584,)\n", + "log_probs looks like (2584,)\n", + "logs prob looks like torch.Size([2584])\n", + "torch.from_numpy(rewards) looks like torch.Size([2584])\n", + "rewards looks like (1792,)\n", + "log_probs looks like (1792,)\n", + "logs prob looks like torch.Size([1792])\n", + "torch.from_numpy(rewards) looks like torch.Size([1792])\n", + "rewards looks like (1613,)\n", + "log_probs looks like (1613,)\n", + "logs prob looks like torch.Size([1613])\n", + "torch.from_numpy(rewards) looks like torch.Size([1613])\n", + "rewards looks like (4300,)\n", + "log_probs looks like (4300,)\n", + "logs prob looks like torch.Size([4300])\n", + "torch.from_numpy(rewards) looks like torch.Size([4300])\n", + "rewards looks like (1602,)\n", + "log_probs looks like (1602,)\n", + "logs prob looks like torch.Size([1602])\n", + "torch.from_numpy(rewards) looks like torch.Size([1602])\n", + "rewards looks like (3313,)\n", + "log_probs looks like (3313,)\n", + "logs prob looks like torch.Size([3313])\n", + "torch.from_numpy(rewards) looks like torch.Size([3313])\n", + "rewards looks like (1538,)\n", + "log_probs looks like (1538,)\n", + "logs prob looks like torch.Size([1538])\n", + "torch.from_numpy(rewards) looks like torch.Size([1538])\n", + "rewards looks like (1824,)\n", + "log_probs looks like (1824,)\n", + "logs prob looks like torch.Size([1824])\n", + "torch.from_numpy(rewards) looks like torch.Size([1824])\n", + "rewards looks like (1320,)\n", + "log_probs looks like (1320,)\n", + "logs prob looks like torch.Size([1320])\n", 
+ "torch.from_numpy(rewards) looks like torch.Size([1320])\n", + "rewards looks like (2077,)\n", + "log_probs looks like (2077,)\n", + "logs prob looks like torch.Size([2077])\n", + "torch.from_numpy(rewards) looks like torch.Size([2077])\n", + "rewards looks like (1995,)\n", + "log_probs looks like (1995,)\n", + "logs prob looks like torch.Size([1995])\n", + "torch.from_numpy(rewards) looks like torch.Size([1995])\n", + "rewards looks like (1089,)\n", + "log_probs looks like (1089,)\n", + "logs prob looks like torch.Size([1089])\n", + "torch.from_numpy(rewards) looks like torch.Size([1089])\n", + "rewards looks like (1135,)\n", + "log_probs looks like (1135,)\n", + "logs prob looks like torch.Size([1135])\n", + "torch.from_numpy(rewards) looks like torch.Size([1135])\n", + "rewards looks like (1617,)\n", + "log_probs looks like (1617,)\n", + "logs prob looks like torch.Size([1617])\n", + "torch.from_numpy(rewards) looks like torch.Size([1617])\n", + "rewards looks like (942,)\n", + "log_probs looks like (942,)\n", + "logs prob looks like torch.Size([942])\n", + "torch.from_numpy(rewards) looks like torch.Size([942])\n", + "rewards looks like (2006,)\n", + "log_probs looks like (2006,)\n", + "logs prob looks like torch.Size([2006])\n", + "torch.from_numpy(rewards) looks like torch.Size([2006])\n", + "rewards looks like (2204,)\n", + "log_probs looks like (2204,)\n", + "logs prob looks like torch.Size([2204])\n", + "torch.from_numpy(rewards) looks like torch.Size([2204])\n", + "rewards looks like (1060,)\n", + "log_probs looks like (1060,)\n", + "logs prob looks like torch.Size([1060])\n", + "torch.from_numpy(rewards) looks like torch.Size([1060])\n", + "rewards looks like (1994,)\n", + "log_probs looks like (1994,)\n", + "logs prob looks like torch.Size([1994])\n", + "torch.from_numpy(rewards) looks like torch.Size([1994])\n", + "rewards looks like (1118,)\n", + "log_probs looks like (1118,)\n", + "logs prob looks like torch.Size([1118])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1118])\n", + "rewards looks like (1298,)\n", + "log_probs looks like (1298,)\n", + "logs prob looks like torch.Size([1298])\n", + "torch.from_numpy(rewards) looks like torch.Size([1298])\n", + "rewards looks like (1377,)\n", + "log_probs looks like (1377,)\n", + "logs prob looks like torch.Size([1377])\n", + "torch.from_numpy(rewards) looks like torch.Size([1377])\n", + "rewards looks like (1902,)\n", + "log_probs looks like (1902,)\n", + "logs prob looks like torch.Size([1902])\n", + "torch.from_numpy(rewards) looks like torch.Size([1902])\n", + "rewards looks like (1982,)\n", + "log_probs looks like (1982,)\n", + "logs prob looks like torch.Size([1982])\n", + "torch.from_numpy(rewards) looks like torch.Size([1982])\n", + "rewards looks like (1625,)\n", + "log_probs looks like (1625,)\n", + "logs prob looks like torch.Size([1625])\n", + "torch.from_numpy(rewards) looks like torch.Size([1625])\n", + "rewards looks like (1947,)\n", + "log_probs looks like (1947,)\n", + "logs prob looks like torch.Size([1947])\n", + "torch.from_numpy(rewards) looks like torch.Size([1947])\n", + "rewards looks like (1589,)\n", + "log_probs looks like (1589,)\n", + "logs prob looks like torch.Size([1589])\n", + "torch.from_numpy(rewards) looks like torch.Size([1589])\n", + "rewards looks like (1625,)\n", + "log_probs looks like (1625,)\n", + "logs prob looks like torch.Size([1625])\n", + "torch.from_numpy(rewards) looks like torch.Size([1625])\n", + "rewards looks like (1492,)\n", + "log_probs looks like (1492,)\n", + "logs prob looks like torch.Size([1492])\n", + "torch.from_numpy(rewards) looks like torch.Size([1492])\n", + "rewards looks like (1347,)\n", + "log_probs looks like (1347,)\n", + "logs prob looks like torch.Size([1347])\n", + "torch.from_numpy(rewards) looks like torch.Size([1347])\n", + "rewards looks like (2110,)\n", + "log_probs looks like (2110,)\n", + "logs prob looks like torch.Size([2110])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2110])\n", + "rewards looks like (877,)\n", + "log_probs looks like (877,)\n", + "logs prob looks like torch.Size([877])\n", + "torch.from_numpy(rewards) looks like torch.Size([877])\n", + "rewards looks like (1078,)\n", + "log_probs looks like (1078,)\n", + "logs prob looks like torch.Size([1078])\n", + "torch.from_numpy(rewards) looks like torch.Size([1078])\n", + "rewards looks like (2001,)\n", + "log_probs looks like (2001,)\n", + "logs prob looks like torch.Size([2001])\n", + "torch.from_numpy(rewards) looks like torch.Size([2001])\n", + "rewards looks like (1452,)\n", + "log_probs looks like (1452,)\n", + "logs prob looks like torch.Size([1452])\n", + "torch.from_numpy(rewards) looks like torch.Size([1452])\n", + "rewards looks like (1169,)\n", + "log_probs looks like (1169,)\n", + "logs prob looks like torch.Size([1169])\n", + "torch.from_numpy(rewards) looks like torch.Size([1169])\n", + "rewards looks like (1977,)\n", + "log_probs looks like (1977,)\n", + "logs prob looks like torch.Size([1977])\n", + "torch.from_numpy(rewards) looks like torch.Size([1977])\n", + "rewards looks like (1263,)\n", + "log_probs looks like (1263,)\n", + "logs prob looks like torch.Size([1263])\n", + "torch.from_numpy(rewards) looks like torch.Size([1263])\n", + "rewards looks like (2219,)\n", + "log_probs looks like (2219,)\n", + "logs prob looks like torch.Size([2219])\n", + "torch.from_numpy(rewards) looks like torch.Size([2219])\n", + "rewards looks like (1732,)\n", + "log_probs looks like (1732,)\n", + "logs prob looks like torch.Size([1732])\n", + "torch.from_numpy(rewards) looks like torch.Size([1732])\n", + "rewards looks like (1413,)\n", + "log_probs looks like (1413,)\n", + "logs prob looks like torch.Size([1413])\n", + "torch.from_numpy(rewards) looks like torch.Size([1413])\n", + "rewards looks like (1099,)\n", + "log_probs looks like (1099,)\n", + "logs prob looks like torch.Size([1099])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1099])\n", + "rewards looks like (1184,)\n", + "log_probs looks like (1184,)\n", + "logs prob looks like torch.Size([1184])\n", + "torch.from_numpy(rewards) looks like torch.Size([1184])\n", + "rewards looks like (1148,)\n", + "log_probs looks like (1148,)\n", + "logs prob looks like torch.Size([1148])\n", + "torch.from_numpy(rewards) looks like torch.Size([1148])\n", + "rewards looks like (1339,)\n", + "log_probs looks like (1339,)\n", + "logs prob looks like torch.Size([1339])\n", + "torch.from_numpy(rewards) looks like torch.Size([1339])\n", + "rewards looks like (2095,)\n", + "log_probs looks like (2095,)\n", + "logs prob looks like torch.Size([2095])\n", + "torch.from_numpy(rewards) looks like torch.Size([2095])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (1276,)\n", + "log_probs looks like (1276,)\n", + "logs prob looks like torch.Size([1276])\n", + "torch.from_numpy(rewards) looks like torch.Size([1276])\n", + "rewards looks like (1277,)\n", + "log_probs looks like (1277,)\n", + "logs prob looks like torch.Size([1277])\n", + "torch.from_numpy(rewards) looks like torch.Size([1277])\n", + "rewards looks like (1453,)\n", + "log_probs looks like (1453,)\n", + "logs prob looks like torch.Size([1453])\n", + "torch.from_numpy(rewards) looks like torch.Size([1453])\n", + "rewards looks like (1467,)\n", + "log_probs looks like (1467,)\n", + "logs prob looks like torch.Size([1467])\n", + "torch.from_numpy(rewards) looks like torch.Size([1467])\n", + "rewards looks like (1383,)\n", + "log_probs looks like (1383,)\n", + "logs prob looks like torch.Size([1383])\n", + "torch.from_numpy(rewards) looks like torch.Size([1383])\n", + "rewards looks like (1741,)\n", + "log_probs looks like (1741,)\n", + "logs prob looks like torch.Size([1741])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1741])\n", + "rewards looks like (1039,)\n", + "log_probs looks like (1039,)\n", + "logs prob looks like torch.Size([1039])\n", + "torch.from_numpy(rewards) looks like torch.Size([1039])\n", + "rewards looks like (1063,)\n", + "log_probs looks like (1063,)\n", + "logs prob looks like torch.Size([1063])\n", + "torch.from_numpy(rewards) looks like torch.Size([1063])\n", + "rewards looks like (1731,)\n", + "log_probs looks like (1731,)\n", + "logs prob looks like torch.Size([1731])\n", + "torch.from_numpy(rewards) looks like torch.Size([1731])\n", + "rewards looks like (2661,)\n", + "log_probs looks like (2661,)\n", + "logs prob looks like torch.Size([2661])\n", + "torch.from_numpy(rewards) looks like torch.Size([2661])\n", + "rewards looks like (704,)\n", + "log_probs looks like (704,)\n", + "logs prob looks like torch.Size([704])\n", + "torch.from_numpy(rewards) looks like torch.Size([704])\n", + "rewards looks like (1389,)\n", + "log_probs looks like (1389,)\n", + "logs prob looks like torch.Size([1389])\n", + "torch.from_numpy(rewards) looks like torch.Size([1389])\n", + "rewards looks like (2131,)\n", + "log_probs looks like (2131,)\n", + "logs prob looks like torch.Size([2131])\n", + "torch.from_numpy(rewards) looks like torch.Size([2131])\n", + "rewards looks like (1779,)\n", + "log_probs looks like (1779,)\n", + "logs prob looks like torch.Size([1779])\n", + "torch.from_numpy(rewards) looks like torch.Size([1779])\n", + "rewards looks like (1415,)\n", + "log_probs looks like (1415,)\n", + "logs prob looks like torch.Size([1415])\n", + "torch.from_numpy(rewards) looks like torch.Size([1415])\n", + "rewards looks like (2320,)\n", + "log_probs looks like (2320,)\n", + "logs prob looks like torch.Size([2320])\n", + "torch.from_numpy(rewards) looks like torch.Size([2320])\n", + "rewards looks like (1147,)\n", + "log_probs looks like (1147,)\n", + "logs prob looks like torch.Size([1147])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1147])\n", + "rewards looks like (1022,)\n", + "log_probs looks like (1022,)\n", + "logs prob looks like torch.Size([1022])\n", + "torch.from_numpy(rewards) looks like torch.Size([1022])\n", + "rewards looks like (2141,)\n", + "log_probs looks like (2141,)\n", + "logs prob looks like torch.Size([2141])\n", + "torch.from_numpy(rewards) looks like torch.Size([2141])\n", + "rewards looks like (1362,)\n", + "log_probs looks like (1362,)\n", + "logs prob looks like torch.Size([1362])\n", + "torch.from_numpy(rewards) looks like torch.Size([1362])\n", + "rewards looks like (1450,)\n", + "log_probs looks like (1450,)\n", + "logs prob looks like torch.Size([1450])\n", + "torch.from_numpy(rewards) looks like torch.Size([1450])\n", + "rewards looks like (1546,)\n", + "log_probs looks like (1546,)\n", + "logs prob looks like torch.Size([1546])\n", + "torch.from_numpy(rewards) looks like torch.Size([1546])\n", + "rewards looks like (1166,)\n", + "log_probs looks like (1166,)\n", + "logs prob looks like torch.Size([1166])\n", + "torch.from_numpy(rewards) looks like torch.Size([1166])\n", + "rewards looks like (1647,)\n", + "log_probs looks like (1647,)\n", + "logs prob looks like torch.Size([1647])\n", + "torch.from_numpy(rewards) looks like torch.Size([1647])\n", + "rewards looks like (1205,)\n", + "log_probs looks like (1205,)\n", + "logs prob looks like torch.Size([1205])\n", + "torch.from_numpy(rewards) looks like torch.Size([1205])\n", + "rewards looks like (2098,)\n", + "log_probs looks like (2098,)\n", + "logs prob looks like torch.Size([2098])\n", + "torch.from_numpy(rewards) looks like torch.Size([2098])\n", + "rewards looks like (1940,)\n", + "log_probs looks like (1940,)\n", + "logs prob looks like torch.Size([1940])\n", + "torch.from_numpy(rewards) looks like torch.Size([1940])\n", + "rewards looks like (2191,)\n", + "log_probs looks like (2191,)\n", + "logs prob looks like torch.Size([2191])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2191])\n", + "rewards looks like (2740,)\n", + "log_probs looks like (2740,)\n", + "logs prob looks like torch.Size([2740])\n", + "torch.from_numpy(rewards) looks like torch.Size([2740])\n", + "rewards looks like (587,)\n", + "log_probs looks like (587,)\n", + "logs prob looks like torch.Size([587])\n", + "torch.from_numpy(rewards) looks like torch.Size([587])\n", + "rewards looks like (1063,)\n", + "log_probs looks like (1063,)\n", + "logs prob looks like torch.Size([1063])\n", + "torch.from_numpy(rewards) looks like torch.Size([1063])\n", + "rewards looks like (861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1051,)\n", + "log_probs looks like (1051,)\n", + "logs prob looks like torch.Size([1051])\n", + "torch.from_numpy(rewards) looks like torch.Size([1051])\n", + "rewards looks like (1389,)\n", + "log_probs looks like (1389,)\n", + "logs prob looks like torch.Size([1389])\n", + "torch.from_numpy(rewards) looks like torch.Size([1389])\n", + "rewards looks like (1152,)\n", + "log_probs looks like (1152,)\n", + "logs prob looks like torch.Size([1152])\n", + "torch.from_numpy(rewards) looks like torch.Size([1152])\n", + "rewards looks like (1103,)\n", + "log_probs looks like (1103,)\n", + "logs prob looks like torch.Size([1103])\n", + "torch.from_numpy(rewards) looks like torch.Size([1103])\n", + "rewards looks like (1887,)\n", + "log_probs looks like (1887,)\n", + "logs prob looks like torch.Size([1887])\n", + "torch.from_numpy(rewards) looks like torch.Size([1887])\n", + "rewards looks like (1753,)\n", + "log_probs looks like (1753,)\n", + "logs prob looks like torch.Size([1753])\n", + "torch.from_numpy(rewards) looks like torch.Size([1753])\n", + "rewards looks like (1372,)\n", + "log_probs looks like (1372,)\n", + "logs prob looks like torch.Size([1372])\n", + "torch.from_numpy(rewards) 
looks like torch.Size([1372])\n", + "rewards looks like (1056,)\n", + "log_probs looks like (1056,)\n", + "logs prob looks like torch.Size([1056])\n", + "torch.from_numpy(rewards) looks like torch.Size([1056])\n", + "rewards looks like (1465,)\n", + "log_probs looks like (1465,)\n", + "logs prob looks like torch.Size([1465])\n", + "torch.from_numpy(rewards) looks like torch.Size([1465])\n", + "rewards looks like (3297,)\n", + "log_probs looks like (3297,)\n", + "logs prob looks like torch.Size([3297])\n", + "torch.from_numpy(rewards) looks like torch.Size([3297])\n", + "rewards looks like (2492,)\n", + "log_probs looks like (2492,)\n", + "logs prob looks like torch.Size([2492])\n", + "torch.from_numpy(rewards) looks like torch.Size([2492])\n", + "rewards looks like (1580,)\n", + "log_probs looks like (1580,)\n", + "logs prob looks like torch.Size([1580])\n", + "torch.from_numpy(rewards) looks like torch.Size([1580])\n", + "rewards looks like (1357,)\n", + "log_probs looks like (1357,)\n", + "logs prob looks like torch.Size([1357])\n", + "torch.from_numpy(rewards) looks like torch.Size([1357])\n", + "rewards looks like (1227,)\n", + "log_probs looks like (1227,)\n", + "logs prob looks like torch.Size([1227])\n", + "torch.from_numpy(rewards) looks like torch.Size([1227])\n", + "rewards looks like (2123,)\n", + "log_probs looks like (2123,)\n", + "logs prob looks like torch.Size([2123])\n", + "torch.from_numpy(rewards) looks like torch.Size([2123])\n", + "rewards looks like (1864,)\n", + "log_probs looks like (1864,)\n", + "logs prob looks like torch.Size([1864])\n", + "torch.from_numpy(rewards) looks like torch.Size([1864])\n", + "rewards looks like (1324,)\n", + "log_probs looks like (1324,)\n", + "logs prob looks like torch.Size([1324])\n", + "torch.from_numpy(rewards) looks like torch.Size([1324])\n", + "rewards looks like (1281,)\n", + "log_probs looks like (1281,)\n", + "logs prob looks like torch.Size([1281])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([1281])\n", + "rewards looks like (1366,)\n", + "log_probs looks like (1366,)\n", + "logs prob looks like torch.Size([1366])\n", + "torch.from_numpy(rewards) looks like torch.Size([1366])\n", + "rewards looks like (957,)\n", + "log_probs looks like (957,)\n", + "logs prob looks like torch.Size([957])\n", + "torch.from_numpy(rewards) looks like torch.Size([957])\n", + "rewards looks like (1187,)\n", + "log_probs looks like (1187,)\n", + "logs prob looks like torch.Size([1187])\n", + "torch.from_numpy(rewards) looks like torch.Size([1187])\n", + "rewards looks like (1625,)\n", + "log_probs looks like (1625,)\n", + "logs prob looks like torch.Size([1625])\n", + "torch.from_numpy(rewards) looks like torch.Size([1625])\n", + "rewards looks like (1605,)\n", + "log_probs looks like (1605,)\n", + "logs prob looks like torch.Size([1605])\n", + "torch.from_numpy(rewards) looks like torch.Size([1605])\n", + "rewards looks like (1015,)\n", + "log_probs looks like (1015,)\n", + "logs prob looks like torch.Size([1015])\n", + "torch.from_numpy(rewards) looks like torch.Size([1015])\n", + "rewards looks like (1565,)\n", + "log_probs looks like (1565,)\n", + "logs prob looks like torch.Size([1565])\n", + "torch.from_numpy(rewards) looks like torch.Size([1565])\n", + "rewards looks like (1353,)\n", + "log_probs looks like (1353,)\n", + "logs prob looks like torch.Size([1353])\n", + "torch.from_numpy(rewards) looks like torch.Size([1353])\n", + "rewards looks like (1321,)\n", + "log_probs looks like (1321,)\n", + "logs prob looks like torch.Size([1321])\n", + "torch.from_numpy(rewards) looks like torch.Size([1321])\n", + "rewards looks like (1074,)\n", + "log_probs looks like (1074,)\n", + "logs prob looks like torch.Size([1074])\n", + "torch.from_numpy(rewards) looks like torch.Size([1074])\n", + "rewards looks like (1301,)\n", + "log_probs looks like (1301,)\n", + "logs prob looks like torch.Size([1301])\n", + "torch.from_numpy(rewards) looks like torch.Size([1301])\n", 
+ "rewards looks like (2105,)\n", + "log_probs looks like (2105,)\n", + "logs prob looks like torch.Size([2105])\n", + "torch.from_numpy(rewards) looks like torch.Size([2105])\n", + "rewards looks like (2008,)\n", + "log_probs looks like (2008,)\n", + "logs prob looks like torch.Size([2008])\n", + "torch.from_numpy(rewards) looks like torch.Size([2008])\n", + "rewards looks like (1885,)\n", + "log_probs looks like (1885,)\n", + "logs prob looks like torch.Size([1885])\n", + "torch.from_numpy(rewards) looks like torch.Size([1885])\n", + "rewards looks like (1184,)\n", + "log_probs looks like (1184,)\n", + "logs prob looks like torch.Size([1184])\n", + "torch.from_numpy(rewards) looks like torch.Size([1184])\n", + "rewards looks like (2551,)\n", + "log_probs looks like (2551,)\n", + "logs prob looks like torch.Size([2551])\n", + "torch.from_numpy(rewards) looks like torch.Size([2551])\n", + "rewards looks like (1330,)\n", + "log_probs looks like (1330,)\n", + "logs prob looks like torch.Size([1330])\n", + "torch.from_numpy(rewards) looks like torch.Size([1330])\n", + "rewards looks like (1510,)\n", + "log_probs looks like (1510,)\n", + "logs prob looks like torch.Size([1510])\n", + "torch.from_numpy(rewards) looks like torch.Size([1510])\n", + "rewards looks like (1330,)\n", + "log_probs looks like (1330,)\n", + "logs prob looks like torch.Size([1330])\n", + "torch.from_numpy(rewards) looks like torch.Size([1330])\n", + "rewards looks like (2157,)\n", + "log_probs looks like (2157,)\n", + "logs prob looks like torch.Size([2157])\n", + "torch.from_numpy(rewards) looks like torch.Size([2157])\n", + "rewards looks like (1276,)\n", + "log_probs looks like (1276,)\n", + "logs prob looks like torch.Size([1276])\n", + "torch.from_numpy(rewards) looks like torch.Size([1276])\n", + "rewards looks like (1188,)\n", + "log_probs looks like (1188,)\n", + "logs prob looks like torch.Size([1188])\n", + "torch.from_numpy(rewards) looks like torch.Size([1188])\n", + "rewards looks 
like (2381,)\n", + "log_probs looks like (2381,)\n", + "logs prob looks like torch.Size([2381])\n", + "torch.from_numpy(rewards) looks like torch.Size([2381])\n", + "rewards looks like (1450,)\n", + "log_probs looks like (1450,)\n", + "logs prob looks like torch.Size([1450])\n", + "torch.from_numpy(rewards) looks like torch.Size([1450])\n", + "rewards looks like (1612,)\n", + "log_probs looks like (1612,)\n", + "logs prob looks like torch.Size([1612])\n", + "torch.from_numpy(rewards) looks like torch.Size([1612])\n", + "rewards looks like (1780,)\n", + "log_probs looks like (1780,)\n", + "logs prob looks like torch.Size([1780])\n", + "torch.from_numpy(rewards) looks like torch.Size([1780])\n", + "rewards looks like (1350,)\n", + "log_probs looks like (1350,)\n", + "logs prob looks like torch.Size([1350])\n", + "torch.from_numpy(rewards) looks like torch.Size([1350])\n", + "rewards looks like (1459,)\n", + "log_probs looks like (1459,)\n", + "logs prob looks like torch.Size([1459])\n", + "torch.from_numpy(rewards) looks like torch.Size([1459])\n", + "rewards looks like (1958,)\n", + "log_probs looks like (1958,)\n", + "logs prob looks like torch.Size([1958])\n", + "torch.from_numpy(rewards) looks like torch.Size([1958])\n", + "rewards looks like (1325,)\n", + "log_probs looks like (1325,)\n", + "logs prob looks like torch.Size([1325])\n", + "torch.from_numpy(rewards) looks like torch.Size([1325])\n", + "rewards looks like (2168,)\n", + "log_probs looks like (2168,)\n", + "logs prob looks like torch.Size([2168])\n", + "torch.from_numpy(rewards) looks like torch.Size([2168])\n", + "rewards looks like (1682,)\n", + "log_probs looks like (1682,)\n", + "logs prob looks like torch.Size([1682])\n", + "torch.from_numpy(rewards) looks like torch.Size([1682])\n", + "rewards looks like (852,)\n", + "log_probs looks like (852,)\n", + "logs prob looks like torch.Size([852])\n", + "torch.from_numpy(rewards) looks like torch.Size([852])\n", + "rewards looks like (1757,)\n", + 
"log_probs looks like (1757,)\n", + "logs prob looks like torch.Size([1757])\n", + "torch.from_numpy(rewards) looks like torch.Size([1757])\n", + "rewards looks like (2313,)\n", + "log_probs looks like (2313,)\n", + "logs prob looks like torch.Size([2313])\n", + "torch.from_numpy(rewards) looks like torch.Size([2313])\n", + "rewards looks like (1662,)\n", + "log_probs looks like (1662,)\n", + "logs prob looks like torch.Size([1662])\n", + "torch.from_numpy(rewards) looks like torch.Size([1662])\n", + "rewards looks like (1559,)\n", + "log_probs looks like (1559,)\n", + "logs prob looks like torch.Size([1559])\n", + "torch.from_numpy(rewards) looks like torch.Size([1559])\n", + "rewards looks like (2077,)\n", + "log_probs looks like (2077,)\n", + "logs prob looks like torch.Size([2077])\n", + "torch.from_numpy(rewards) looks like torch.Size([2077])\n", + "rewards looks like (2119,)\n", + "log_probs looks like (2119,)\n", + "logs prob looks like torch.Size([2119])\n", + "torch.from_numpy(rewards) looks like torch.Size([2119])\n", + "rewards looks like (954,)\n", + "log_probs looks like (954,)\n", + "logs prob looks like torch.Size([954])\n", + "torch.from_numpy(rewards) looks like torch.Size([954])\n", + "rewards looks like (1797,)\n", + "log_probs looks like (1797,)\n", + "logs prob looks like torch.Size([1797])\n", + "torch.from_numpy(rewards) looks like torch.Size([1797])\n", + "rewards looks like (1579,)\n", + "log_probs looks like (1579,)\n", + "logs prob looks like torch.Size([1579])\n", + "torch.from_numpy(rewards) looks like torch.Size([1579])\n", + "rewards looks like (1277,)\n", + "log_probs looks like (1277,)\n", + "logs prob looks like torch.Size([1277])\n", + "torch.from_numpy(rewards) looks like torch.Size([1277])\n", + "rewards looks like (1196,)\n", + "log_probs looks like (1196,)\n", + "logs prob looks like torch.Size([1196])\n", + "torch.from_numpy(rewards) looks like torch.Size([1196])\n", + "rewards looks like (1294,)\n", + "log_probs looks like 
(1294,)\n", + "logs prob looks like torch.Size([1294])\n", + "torch.from_numpy(rewards) looks like torch.Size([1294])\n", + "rewards looks like (1318,)\n", + "log_probs looks like (1318,)\n", + "logs prob looks like torch.Size([1318])\n", + "torch.from_numpy(rewards) looks like torch.Size([1318])\n", + "rewards looks like (2605,)\n", + "log_probs looks like (2605,)\n", + "logs prob looks like torch.Size([2605])\n", + "torch.from_numpy(rewards) looks like torch.Size([2605])\n", + "rewards looks like (2002,)\n", + "log_probs looks like (2002,)\n", + "logs prob looks like torch.Size([2002])\n", + "torch.from_numpy(rewards) looks like torch.Size([2002])\n", + "rewards looks like (1354,)\n", + "log_probs looks like (1354,)\n", + "logs prob looks like torch.Size([1354])\n", + "torch.from_numpy(rewards) looks like torch.Size([1354])\n", + "rewards looks like (1785,)\n", + "log_probs looks like (1785,)\n", + "logs prob looks like torch.Size([1785])\n", + "torch.from_numpy(rewards) looks like torch.Size([1785])\n", + "rewards looks like (781,)\n", + "log_probs looks like (781,)\n", + "logs prob looks like torch.Size([781])\n", + "torch.from_numpy(rewards) looks like torch.Size([781])\n", + "rewards looks like (1965,)\n", + "log_probs looks like (1965,)\n", + "logs prob looks like torch.Size([1965])\n", + "torch.from_numpy(rewards) looks like torch.Size([1965])\n", + "rewards looks like (1135,)\n", + "log_probs looks like (1135,)\n", + "logs prob looks like torch.Size([1135])\n", + "torch.from_numpy(rewards) looks like torch.Size([1135])\n", + "rewards looks like (1672,)\n", + "log_probs looks like (1672,)\n", + "logs prob looks like torch.Size([1672])\n", + "torch.from_numpy(rewards) looks like torch.Size([1672])\n", + "rewards looks like (1278,)\n", + "log_probs looks like (1278,)\n", + "logs prob looks like torch.Size([1278])\n", + "torch.from_numpy(rewards) looks like torch.Size([1278])\n", + "rewards looks like (2499,)\n", + "log_probs looks like (2499,)\n", + "logs 
prob looks like torch.Size([2499])\n", + "torch.from_numpy(rewards) looks like torch.Size([2499])\n", + "rewards looks like (1275,)\n", + "log_probs looks like (1275,)\n", + "logs prob looks like torch.Size([1275])\n", + "torch.from_numpy(rewards) looks like torch.Size([1275])\n", + "rewards looks like (1144,)\n", + "log_probs looks like (1144,)\n", + "logs prob looks like torch.Size([1144])\n", + "torch.from_numpy(rewards) looks like torch.Size([1144])\n", + "rewards looks like (1605,)\n", + "log_probs looks like (1605,)\n", + "logs prob looks like torch.Size([1605])\n", + "torch.from_numpy(rewards) looks like torch.Size([1605])\n", + "rewards looks like (1178,)\n", + "log_probs looks like (1178,)\n", + "logs prob looks like torch.Size([1178])\n", + "torch.from_numpy(rewards) looks like torch.Size([1178])\n", + "rewards looks like (3269,)\n", + "log_probs looks like (3269,)\n", + "logs prob looks like torch.Size([3269])\n", + "torch.from_numpy(rewards) looks like torch.Size([3269])\n", + "rewards looks like (1492,)\n", + "log_probs looks like (1492,)\n", + "logs prob looks like torch.Size([1492])\n", + "torch.from_numpy(rewards) looks like torch.Size([1492])\n", + "rewards looks like (1285,)\n", + "log_probs looks like (1285,)\n", + "logs prob looks like torch.Size([1285])\n", + "torch.from_numpy(rewards) looks like torch.Size([1285])\n", + "rewards looks like (1687,)\n", + "log_probs looks like (1687,)\n", + "logs prob looks like torch.Size([1687])\n", + "torch.from_numpy(rewards) looks like torch.Size([1687])\n", + "rewards looks like (1124,)\n", + "log_probs looks like (1124,)\n", + "logs prob looks like torch.Size([1124])\n", + "torch.from_numpy(rewards) looks like torch.Size([1124])\n", + "rewards looks like (2043,)\n", + "log_probs looks like (2043,)\n", + "logs prob looks like torch.Size([2043])\n", + "torch.from_numpy(rewards) looks like torch.Size([2043])\n", + "rewards looks like (1280,)\n", + "log_probs looks like (1280,)\n", + "logs prob looks like 
torch.Size([1280])\n", + "torch.from_numpy(rewards) looks like torch.Size([1280])\n", + "rewards looks like (1418,)\n", + "log_probs looks like (1418,)\n", + "logs prob looks like torch.Size([1418])\n", + "torch.from_numpy(rewards) looks like torch.Size([1418])\n", + "rewards looks like (1365,)\n", + "log_probs looks like (1365,)\n", + "logs prob looks like torch.Size([1365])\n", + "torch.from_numpy(rewards) looks like torch.Size([1365])\n", + "rewards looks like (1091,)\n", + "log_probs looks like (1091,)\n", + "logs prob looks like torch.Size([1091])\n", + "torch.from_numpy(rewards) looks like torch.Size([1091])\n", + "rewards looks like (1279,)\n", + "log_probs looks like (1279,)\n", + "logs prob looks like torch.Size([1279])\n", + "torch.from_numpy(rewards) looks like torch.Size([1279])\n", + "rewards looks like (1109,)\n", + "log_probs looks like (1109,)\n", + "logs prob looks like torch.Size([1109])\n", + "torch.from_numpy(rewards) looks like torch.Size([1109])\n", + "rewards looks like (1285,)\n", + "log_probs looks like (1285,)\n", + "logs prob looks like torch.Size([1285])\n", + "torch.from_numpy(rewards) looks like torch.Size([1285])\n", + "rewards looks like (1222,)\n", + "log_probs looks like (1222,)\n", + "logs prob looks like torch.Size([1222])\n", + "torch.from_numpy(rewards) looks like torch.Size([1222])\n", + "rewards looks like (1538,)\n", + "log_probs looks like (1538,)\n", + "logs prob looks like torch.Size([1538])\n", + "torch.from_numpy(rewards) looks like torch.Size([1538])\n", + "rewards looks like (1139,)\n", + "log_probs looks like (1139,)\n", + "logs prob looks like torch.Size([1139])\n", + "torch.from_numpy(rewards) looks like torch.Size([1139])\n", + "rewards looks like (1354,)\n", + "log_probs looks like (1354,)\n", + "logs prob looks like torch.Size([1354])\n", + "torch.from_numpy(rewards) looks like torch.Size([1354])\n", + "rewards looks like (1166,)\n", + "log_probs looks like (1166,)\n", + "logs prob looks like 
torch.Size([1166])\n", + "torch.from_numpy(rewards) looks like torch.Size([1166])\n", + "rewards looks like (1348,)\n", + "log_probs looks like (1348,)\n", + "logs prob looks like torch.Size([1348])\n", + "torch.from_numpy(rewards) looks like torch.Size([1348])\n", + "rewards looks like (1347,)\n", + "log_probs looks like (1347,)\n", + "logs prob looks like torch.Size([1347])\n", + "torch.from_numpy(rewards) looks like torch.Size([1347])\n", + "rewards looks like (2059,)\n", + "log_probs looks like (2059,)\n", + "logs prob looks like torch.Size([2059])\n", + "torch.from_numpy(rewards) looks like torch.Size([2059])\n", + "rewards looks like (2021,)\n", + "log_probs looks like (2021,)\n", + "logs prob looks like torch.Size([2021])\n", + "torch.from_numpy(rewards) looks like torch.Size([2021])\n", + "rewards looks like (2232,)\n", + "log_probs looks like (2232,)\n", + "logs prob looks like torch.Size([2232])\n", + "torch.from_numpy(rewards) looks like torch.Size([2232])\n", + "rewards looks like (1102,)\n", + "log_probs looks like (1102,)\n", + "logs prob looks like torch.Size([1102])\n", + "torch.from_numpy(rewards) looks like torch.Size([1102])\n", + "rewards looks like (1165,)\n", + "log_probs looks like (1165,)\n", + "logs prob looks like torch.Size([1165])\n", + "torch.from_numpy(rewards) looks like torch.Size([1165])\n", + "rewards looks like (1264,)\n", + "log_probs looks like (1264,)\n", + "logs prob looks like torch.Size([1264])\n", + "torch.from_numpy(rewards) looks like torch.Size([1264])\n", + "rewards looks like (1346,)\n", + "log_probs looks like (1346,)\n", + "logs prob looks like torch.Size([1346])\n", + "torch.from_numpy(rewards) looks like torch.Size([1346])\n", + "rewards looks like (2848,)\n", + "log_probs looks like (2848,)\n", + "logs prob looks like torch.Size([2848])\n", + "torch.from_numpy(rewards) looks like torch.Size([2848])\n", + "rewards looks like (938,)\n", + "log_probs looks like (938,)\n", + "logs prob looks like torch.Size([938])\n", 
+ "torch.from_numpy(rewards) looks like torch.Size([938])\n", + "rewards looks like (1069,)\n", + "log_probs looks like (1069,)\n", + "logs prob looks like torch.Size([1069])\n", + "torch.from_numpy(rewards) looks like torch.Size([1069])\n", + "rewards looks like (2588,)\n", + "log_probs looks like (2588,)\n", + "logs prob looks like torch.Size([2588])\n", + "torch.from_numpy(rewards) looks like torch.Size([2588])\n", + "rewards looks like (1461,)\n", + "log_probs looks like (1461,)\n", + "logs prob looks like torch.Size([1461])\n", + "torch.from_numpy(rewards) looks like torch.Size([1461])\n", + "rewards looks like (2153,)\n", + "log_probs looks like (2153,)\n", + "logs prob looks like torch.Size([2153])\n", + "torch.from_numpy(rewards) looks like torch.Size([2153])\n", + "rewards looks like (2312,)\n", + "log_probs looks like (2312,)\n", + "logs prob looks like torch.Size([2312])\n", + "torch.from_numpy(rewards) looks like torch.Size([2312])\n", + "rewards looks like (1636,)\n", + "log_probs looks like (1636,)\n", + "logs prob looks like torch.Size([1636])\n", + "torch.from_numpy(rewards) looks like torch.Size([1636])\n", + "rewards looks like (2019,)\n", + "log_probs looks like (2019,)\n", + "logs prob looks like torch.Size([2019])\n", + "torch.from_numpy(rewards) looks like torch.Size([2019])\n", + "rewards looks like (1450,)\n", + "log_probs looks like (1450,)\n", + "logs prob looks like torch.Size([1450])\n", + "torch.from_numpy(rewards) looks like torch.Size([1450])\n", + "rewards looks like (2105,)\n", + "log_probs looks like (2105,)\n", + "logs prob looks like torch.Size([2105])\n", + "torch.from_numpy(rewards) looks like torch.Size([2105])\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vNb_tuFYhKVK" + }, + "source": [ + "### Training Result\n", + "During the training process, we recorded `avg_total_reward`, which represents the average total reward of episodes before updating the policy network.\n", 
+ "\n", + "Theoretically, if the agent becomes better, the `avg_total_reward` will increase.\n", + "The visualization of the training process is shown below: \n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "wZYOI8H10SHN", + "outputId": "54840043-6fe4-4771-e8c9-78785c55aa79" + }, + "source": [ + "plt.plot(avg_total_rewards)\n", + "plt.title(\"Total Rewards\")\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mV5jj4dThz0Y" + }, + "source": [ + "In addition, `avg_final_reward` represents average final rewards of episodes. To be specific, final rewards is the last reward received in one episode, indicating whether the craft lands successfully or not.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "txDZ5vlGWz5w", + "outputId": "c7c5e9ca-6329-4ee2-f3d6-a1ba46b5aea2" + }, + "source": [ + "plt.plot(avg_final_rewards)\n", + "plt.title(\"Final Rewards\")\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u2HaGRVEYGQS" + }, + "source": [ + "## Testing\n", + "The testing result will be the average reward of 5 testing" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 286 + }, + "id": "5yFuUKKRYH73", + "outputId": "3ad20d21-4c56-47b7-e617-d71e4211aed7" + }, + "source": [ + "fix(env, seed)\n", + "agent.network.eval() # set the network into evaluation mode\n", + "NUM_OF_TEST = 5 # Do not revise this !!!\n", + "test_total_reward = []\n", + "action_list = []\n", + "for i in range(NUM_OF_TEST):\n", + " actions = []\n", + " state = env.reset()\n", + "\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + " while not done:\n", + " action, _ = agent.sample(state)\n", + " actions.append(action)\n", + " state, reward, done, _ = env.step(action)\n", + "\n", + " total_reward += reward\n", + "\n", + " img.set_data(env.render(mode='rgb_array'))\n", + " display.display(plt.gcf())\n", + " display.clear_output(wait=True)\n", + " \n", + " print(total_reward)\n", + " test_total_reward.append(total_reward)\n", + "\n", + " action_list.append(actions) # save the result of testing \n" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-207.9114975585693\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Aex7mcKr0J01", + "outputId": "a36aaa35-ec20-4089-ddde-ab4742d3e90e" + }, + "source": [ + "print(np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-147.2620449863271\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "leyebGYRpqsF" + }, + "source": [ + "Action list" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hGAH4YWDpp4u", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "24f547ee-6648-4d2a-f9c7-718b23e93251" + }, + "source": [ + "print(\"Action list looks like \", action_list)\n", + "print(\"Action list's shape looks like \", np.shape(action_list))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Action list looks like [[2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 0, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3], [2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 0, 2, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 2, 2, 2, 3, 3, 2, 3, 0, 2, 3, 2, 0, 2, 3, 3, 2, 3, 2, 3, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 3, 2, 2, 0, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 0, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 2, 3, 2, 3, 3, 3, 2, 3, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3, 3, 2, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 1, 2, 2, 0, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 0, 2, 0, 3, 2, 3, 2, 0, 2, 0, 2, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 3, 2, 2, 3, 2, 2, 3, 2, 0, 2, 3, 2, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 
1, 2, 2, 0, 2, 1, 2, 1, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 3, 3, 2, 3, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 3, 2, 3, 3, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 2, 3, 2, 3, 2, 2, 2, 3, 2, 0, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2, 3, 0, 3, 2, 3, 3, 2, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 2, 3, 3, 3, 2, 2, 3, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 2, 1, 2, 1], [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 1, 2, 1, 2, 2, 1, 1, 2, 0, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 1, 0, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 3, 2, 3, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 3, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 3, 
2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 0, 0, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 0, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 3, 3, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 3, 2, 2, 0, 2, 2, 2, 3, 2, 3, 2, 2, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 2, 3, 2, 3, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]\n", + "Action list's shape looks like (5,)\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " return array(a, dtype, copy=False, order=order)\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fNkmwucrHMen" + }, + "source": [ + "Analysis of actions taken by agent" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WHdAItjj1nxw", + "outputId": "20aea8b1-e011-4397-b775-7b4c4b593871" + }, + "source": [ + "distribution = {}\n", + "for actions in action_list:\n", + " for action in actions:\n", + " if action not in distribution.keys():\n", + " distribution[action] = 1\n", + " else:\n", + " distribution[action] += 1\n", + "print(distribution)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "{2: 1144, 1: 516, 0: 30, 3: 501}\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ricE0schY75M" + }, + "source": [ + "Saving the result of Model Testing\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GZsMkGmIY42b", + "outputId": "c47a3123-eb1b-4dc1-f1b2-7a09a82cd8a6" + }, + "source": [ + "PATH = \"Action_List.npy\" # Can be modified into the name or path you want\n", + "np.save(PATH ,np.array(action_list)) " + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " \n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "asK7WfbkaLjt" + }, + "source": [ + "### This is the file you need to submit !!!\n", + "Download the testing result to your device\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "id": "c-CqyhHzaWAL", + "outputId": "38653c82-673e-4f90-8746-3a0424fe3aca" + }, + "source": [ + "from google.colab import files\n", + "files.download(PATH)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = 
window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "download(\"download_899cef0d-01bc-40fd-a501-1573f2382641\", \"Action_List.npy\", 4689)" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "seT4NUmWmAZ1" + }, + "source": [ + "# Server \n", + "The code below simulates the environment on the judge server. It can be used for testing." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U69c-YTxaw6b", + "outputId": "65dd4cf4-e667-469d-a7dc-34f6a053e1d0" + }, + "source": [ + "action_list = np.load(PATH,allow_pickle=True) # The action list you upload\n", + "seed = 543 # Do not revise this\n", + "fix(env, seed)\n", + "\n", + "agent.network.eval() # set network to evaluation mode\n", + "\n", + "test_total_reward = []\n", + "if len(action_list) != 5:\n", + " print(\"Wrong format of file !!!\")\n", + " exit(0)\n", + "for actions in action_list:\n", + " state = env.reset()\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + "\n", + " for action in actions:\n", + " \n", + " state, reward, done, _ = env.step(action)\n", + " total_reward += reward\n", + " if done:\n", + " break\n", + "\n", + " print(f\"Your reward is : %.2f\"%total_reward)\n", + " test_total_reward.append(total_reward)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/__init__.py:422: UserWarning: torch.set_deterministic is deprecated and will be removed in a future release. 
Please use torch.use_deterministic_algorithms instead\n", + " \"torch.set_deterministic is deprecated and will be removed in a future \"\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "Your reward is : -29.53\n", + "Your reward is : -36.44\n", + "Your reward is : -194.16\n", + "Your reward is : -268.27\n", + "Your reward is : -207.91\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjFBWwQP1hVe" + }, + "source": [ + "# Your score" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GpJpZz3Wbm0X", + "outputId": "7d4677c7-b285-42d3-d8c0-5f0a0d8230c8" + }, + "source": [ + "print(f\"Your final reward is : %.2f\"%np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Your final reward is : -147.26\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wUBtYXG2eaqf" + }, + "source": [ + "## Reference\n", + "\n", + "Below are some useful tips for you to get high score.\n", + "\n", + "- [DRL Lecture 1: Policy Gradient (Review)](https://youtu.be/z95ZYgPgXOY)\n", + "- [ML Lecture 23-3: Reinforcement Learning (including Q-learning) start at 30:00](https://youtu.be/2-JNBzCq77c?t=1800)\n", + "- [Lecture 7: Policy Gradient, David Silver](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/pg.pdf)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eZ7VDw-C19Qe" + }, + "source": [ + "" + ] + } + ] +} \ No newline at end of file diff --git a/范例/HW12/HW12_EN.pdf b/范例/HW12/HW12_EN.pdf new file mode 100644 index 0000000..f07f555 Binary files /dev/null and b/范例/HW12/HW12_EN.pdf differ diff --git a/范例/HW12/HW12_ZH.ipynb b/范例/HW12/HW12_ZH.ipynb new file mode 100644 index 0000000..771a9aa --- /dev/null +++ b/范例/HW12/HW12_ZH.ipynb @@ -0,0 +1,3645 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hw12_reinforcement_learning_chinese_version.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "2acab9542fe64b979fa2ac2adb3f10a8": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_f288c64b5ff748eb82178bf1de17934f", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_de34e5b178f5470e98e0275102a65042", + "IPY_MODEL_c93cba301cac439ca56fb6b45bd1c4e4" + ] + } + }, + "f288c64b5ff748eb82178bf1de17934f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "de34e5b178f5470e98e0275102a65042": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_43c6ee720b674626ab3a869bda5dd6e3", + "_dom_classes": [], 
+ "description": "Total: -24.0, Final: -40.0: 100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 400, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 400, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_2465d2b109d34922a486341232d86ad6" + } + }, + "c93cba301cac439ca56fb6b45bd1c4e4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_aa27187195be4da9874025395eac35eb", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 400/400 [11:02<00:00, 1.66s/it]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_02d196d4f9734f998455d92bd9300adb" + } + }, + "43c6ee720b674626ab3a869bda5dd6e3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "2465d2b109d34922a486341232d86ad6": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": 
null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "aa27187195be4da9874025395eac35eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "02d196d4f9734f998455d92bd9300adb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + 
"max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Fp30SB4bxeQb" + }, + "source": [ + "# **Homework 12 - Reinforcement Learning**\n", + "\n", + "若有任何問題,歡迎來信至助教信箱 ntu-ml-2021spring-ta@googlegroups.com\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yXsnCWPtWSNk" + }, + "source": [ + "## 前置作業\n", + "\n", + "首先我們需要安裝必要的系統套件及 PyPi 套件。\n", + "gym 這個套件由 OpenAI 所提供,是一套用來開發與比較 Reinforcement Learning 演算法的工具包(toolkit)。\n", + "而其餘套件則是為了在 Notebook 中繪圖所需要的套件。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5e2bScpnkVbv", + "outputId": "dd8cf053-de15-4a11-c146-5f3405d1e377" + }, + "source": [ + "!apt update\n", + "!apt install python-opengl xvfb -y\n", + "!pip install gym[box2d]==0.18.3 pyvirtualdisplay tqdm numpy==1.19.5 torch==1.8.1" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\u001b[33m\r0% [Working]\u001b[0m\r \rGet:1 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]\n", + "\u001b[33m\r0% [Connecting to archive.ubuntu.com (91.189.88.142)] [Waiting for headers] [1 \u001b[0m\u001b[33m\r0% [Connecting to archive.ubuntu.com (91.189.88.142)] [Waiting for headers] [Co\u001b[0m\r \rIgn:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 InRelease\n", + "\u001b[33m\r0% [Connecting to archive.ubuntu.com (91.189.88.142)] [Waiting for headers] [Co\u001b[0m\u001b[33m\r0% [1 InRelease gpgv 3,626 B] 
[Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rIgn:3 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 InRelease\n", + "\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rGet:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Release [697 B]\n", + "\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rHit:5 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 Release\n", + "\u001b[33m\r0% [1 InRelease gpgv 3,626 B] [Connecting to archive.ubuntu.com (91.189.88.142)\u001b[0m\r \rGet:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Release.gpg [836 B]\n", + "Get:7 http://security.ubuntu.com/ubuntu bionic-security InRelease [88.7 kB]\n", + "Hit:8 http://archive.ubuntu.com/ubuntu bionic InRelease\n", + "Get:9 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease [15.9 kB]\n", + "Get:10 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ Packages [60.9 kB]\n", + "Get:11 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [88.7 kB]\n", + "Hit:13 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease\n", + "Ign:14 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Packages\n", + "Get:14 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 Packages [798 kB]\n", + "Get:15 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [74.6 kB]\n", + "Hit:16 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease\n", + "Get:17 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu bionic InRelease [21.3 kB]\n", + "Get:18 http://security.ubuntu.com/ubuntu bionic-security/restricted amd64 Packages [423 kB]\n", + "Get:19 http://security.ubuntu.com/ubuntu 
bionic-security/main amd64 Packages [2,152 kB]\n", + "Get:20 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic/main Sources [1,770 kB]\n", + "Get:21 http://security.ubuntu.com/ubuntu bionic-security/universe amd64 Packages [1,413 kB]\n", + "Get:22 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 Packages [2,184 kB]\n", + "Get:23 http://archive.ubuntu.com/ubuntu bionic-updates/restricted amd64 Packages [452 kB]\n", + "Get:24 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 Packages [2,584 kB]\n", + "Get:25 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic/main amd64 Packages [905 kB]\n", + "Get:26 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu bionic/main amd64 Packages [41.5 kB]\n", + "Fetched 13.1 MB in 4s (3,031 kB/s)\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "86 packages can be upgraded. Run 'apt list --upgradable' to see them.\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "The following package was automatically installed and is no longer required:\n", + " libnvidia-common-460\n", + "Use 'apt autoremove' to remove it.\n", + "Suggested packages:\n", + " libgle3\n", + "The following NEW packages will be installed:\n", + " python-opengl xvfb\n", + "0 upgraded, 2 newly installed, 0 to remove and 86 not upgraded.\n", + "Need to get 1,281 kB of archives.\n", + "After this operation, 7,686 kB of additional disk space will be used.\n", + "Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 python-opengl all 3.1.0+dfsg-1 [496 kB]\n", + "Get:2 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 xvfb amd64 2:1.19.6-1ubuntu4.9 [784 kB]\n", + "Fetched 1,281 kB in 1s (977 kB/s)\n", + "Selecting previously unselected package python-opengl.\n", + "(Reading database ... 
160706 files and directories currently installed.)\n", + "Preparing to unpack .../python-opengl_3.1.0+dfsg-1_all.deb ...\n", + "Unpacking python-opengl (3.1.0+dfsg-1) ...\n", + "Selecting previously unselected package xvfb.\n", + "Preparing to unpack .../xvfb_2%3a1.19.6-1ubuntu4.9_amd64.deb ...\n", + "Unpacking xvfb (2:1.19.6-1ubuntu4.9) ...\n", + "Setting up python-opengl (3.1.0+dfsg-1) ...\n", + "Setting up xvfb (2:1.19.6-1ubuntu4.9) ...\n", + "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n", + "Requirement already satisfied: gym[box2d] in /usr/local/lib/python3.7/dist-packages (0.17.3)\n", + "Collecting pyvirtualdisplay\n", + " Downloading https://files.pythonhosted.org/packages/19/88/7a198a5ee3baa3d547f5a49574cd8c3913b216f5276b690b028f89ffb325/PyVirtualDisplay-2.1-py3-none-any.whl\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n", + "Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.3.0)\n", + "Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.19.5)\n", + "Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.5.0)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gym[box2d]) (1.4.1)\n", + "Collecting box2d-py~=2.3.5; extra == \"box2d\"\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/87/34/da5393985c3ff9a76351df6127c275dcb5749ae0abbe8d5210f06d97405d/box2d_py-2.3.8-cp37-cp37m-manylinux1_x86_64.whl (448kB)\n", + "\u001b[K |████████████████████████████████| 450kB 10.3MB/s \n", + "\u001b[?25hCollecting EasyProcess\n", + " Downloading https://files.pythonhosted.org/packages/48/3c/75573613641c90c6d094059ac28adb748560d99bd27ee6f80cce398f404e/EasyProcess-0.3-py2.py3-none-any.whl\n", + "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages 
(from pyglet<=1.5.0,>=1.4.0->gym[box2d]) (0.16.0)\n", + "Installing collected packages: EasyProcess, pyvirtualdisplay, box2d-py\n", + "Successfully installed EasyProcess-0.3 box2d-py-2.3.8 pyvirtualdisplay-2.1\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M_-i3cdoYsks" + }, + "source": [ + "接下來,設置好 virtual display,並引入所有必要的套件。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "nl2nREINDLiw" + }, + "source": [ + "%%capture\n", + "from pyvirtualdisplay import Display\n", + "virtual_display = Display(visible=0, size=(1400, 900))\n", + "virtual_display.start()\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from IPython import display\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.distributions import Categorical\n", + "from tqdm.notebook import tqdm" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HVu9-Vdrl4E3" + }, + "source": [ + "# 請不要更改 random seed !!!!\n", + "# 不然在judgeboi上 你的成績不會被reproduce !!!!" 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fV9i8i2YkRbO" + }, + "source": [ + "seed = 543 # Do not change this\n", + "def fix(env, seed):\n", + " env.seed(seed)\n", + " env.action_space.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " np.random.seed(seed)\n", + " random.seed(seed)\n", + " torch.set_deterministic(True)\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.backends.cudnn.deterministic = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "He0XDx6bzjgC" + }, + "source": [ + "最後,引入 OpenAI 的 gym,並建立一個 [Lunar Lander](https://gym.openai.com/envs/LunarLander-v2/) 環境。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "N_4-xJcbBt09" + }, + "source": [ + "%%capture\n", + "import gym\n", + "import random\n", + "import numpy as np\n", + "\n", + "env = gym.make('LunarLander-v2')\n", + "\n", + "fix(env, seed)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "NmiAOfqRwRX5" + }, + "source": [ + "import time\n", + "start = time.time()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "LcMjEUWTBEEB", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7a5146e4-e877-4d26-fd61-652c57ef1f4e" + }, + "source": [ + "!pip freeze" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "absl-py==0.12.0\n", + "alabaster==0.7.12\n", + "albumentations==0.1.12\n", + "altair==4.1.0\n", + "appdirs==1.4.4\n", + "argon2-cffi==20.1.0\n", + "arviz==0.11.2\n", + "astor==0.8.1\n", + "astropy==4.2.1\n", + "astunparse==1.6.3\n", + "async-generator==1.10\n", + "atari-py==0.2.9\n", + "atomicwrites==1.4.0\n", + "attrs==21.2.0\n", + "audioread==2.1.9\n", + "autograd==1.3\n", + "Babel==2.9.1\n", + "backcall==0.2.0\n", + "beautifulsoup4==4.6.3\n", + 
"bleach==3.3.0\n", + "blis==0.4.1\n", + "bokeh==2.3.2\n", + "Bottleneck==1.3.2\n", + "box2d-py==2.3.8\n", + "branca==0.4.2\n", + "bs4==0.0.1\n", + "CacheControl==0.12.6\n", + "cached-property==1.5.2\n", + "cachetools==4.2.2\n", + "catalogue==1.0.0\n", + "certifi==2020.12.5\n", + "cffi==1.14.5\n", + "cftime==1.5.0\n", + "chainer==7.4.0\n", + "chardet==3.0.4\n", + "click==7.1.2\n", + "cloudpickle==1.3.0\n", + "cmake==3.12.0\n", + "cmdstanpy==0.9.5\n", + "colorcet==2.0.6\n", + "colorlover==0.3.0\n", + "community==1.0.0b1\n", + "contextlib2==0.5.5\n", + "convertdate==2.3.2\n", + "coverage==3.7.1\n", + "coveralls==0.5\n", + "crcmod==1.7\n", + "cufflinks==0.17.3\n", + "cupy-cuda101==7.4.0\n", + "cvxopt==1.2.6\n", + "cvxpy==1.0.31\n", + "cycler==0.10.0\n", + "cymem==2.0.5\n", + "Cython==0.29.23\n", + "daft==0.0.4\n", + "dask==2.12.0\n", + "datascience==0.10.6\n", + "debugpy==1.0.0\n", + "decorator==4.4.2\n", + "defusedxml==0.7.1\n", + "descartes==1.1.0\n", + "dill==0.3.3\n", + "distributed==1.25.3\n", + "dlib==19.18.0\n", + "dm-tree==0.1.6\n", + "docopt==0.6.2\n", + "docutils==0.17.1\n", + "dopamine-rl==1.0.5\n", + "earthengine-api==0.1.266\n", + "easydict==1.9\n", + "EasyProcess==0.3\n", + "ecos==2.0.7.post1\n", + "editdistance==0.5.3\n", + "en-core-web-sm==2.2.5\n", + "entrypoints==0.3\n", + "ephem==3.7.7.1\n", + "et-xmlfile==1.1.0\n", + "fa2==0.3.5\n", + "fastai==1.0.61\n", + "fastdtw==0.3.4\n", + "fastprogress==1.0.0\n", + "fastrlock==0.6\n", + "fbprophet==0.7.1\n", + "feather-format==0.4.1\n", + "filelock==3.0.12\n", + "firebase-admin==4.4.0\n", + "fix-yahoo-finance==0.0.22\n", + "Flask==1.1.4\n", + "flatbuffers==1.12\n", + "folium==0.8.3\n", + "future==0.16.0\n", + "gast==0.4.0\n", + "GDAL==2.2.2\n", + "gdown==3.6.4\n", + "gensim==3.6.0\n", + "geographiclib==1.50\n", + "geopy==1.17.0\n", + "gin-config==0.4.0\n", + "glob2==0.7\n", + "google==2.0.3\n", + "google-api-core==1.26.3\n", + "google-api-python-client==1.12.8\n", + "google-auth==1.30.0\n", + 
"google-auth-httplib2==0.0.4\n", + "google-auth-oauthlib==0.4.4\n", + "google-cloud-bigquery==1.21.0\n", + "google-cloud-bigquery-storage==1.1.0\n", + "google-cloud-core==1.0.3\n", + "google-cloud-datastore==1.8.0\n", + "google-cloud-firestore==1.7.0\n", + "google-cloud-language==1.2.0\n", + "google-cloud-storage==1.18.1\n", + "google-cloud-translate==1.5.0\n", + "google-colab==1.0.0\n", + "google-pasta==0.2.0\n", + "google-resumable-media==0.4.1\n", + "googleapis-common-protos==1.53.0\n", + "googledrivedownloader==0.4\n", + "graphviz==0.10.1\n", + "greenlet==1.1.0\n", + "grpcio==1.34.1\n", + "gspread==3.0.1\n", + "gspread-dataframe==3.0.8\n", + "gym==0.17.3\n", + "h5py==3.1.0\n", + "HeapDict==1.0.1\n", + "hijri-converter==2.1.1\n", + "holidays==0.10.5.2\n", + "holoviews==1.14.3\n", + "html5lib==1.0.1\n", + "httpimport==0.5.18\n", + "httplib2==0.17.4\n", + "httplib2shim==0.0.3\n", + "humanize==0.5.1\n", + "hyperopt==0.1.2\n", + "ideep4py==2.0.0.post3\n", + "idna==2.10\n", + "imageio==2.4.1\n", + "imagesize==1.2.0\n", + "imbalanced-learn==0.4.3\n", + "imblearn==0.0\n", + "imgaug==0.2.9\n", + "importlib-metadata==4.0.1\n", + "importlib-resources==5.1.3\n", + "imutils==0.5.4\n", + "inflect==2.1.0\n", + "iniconfig==1.1.1\n", + "install==1.3.4\n", + "intel-openmp==2021.2.0\n", + "intervaltree==2.1.0\n", + "ipykernel==4.10.1\n", + "ipython==5.5.0\n", + "ipython-genutils==0.2.0\n", + "ipython-sql==0.3.9\n", + "ipywidgets==7.6.3\n", + "itsdangerous==1.1.0\n", + "jax==0.2.13\n", + "jaxlib==0.1.66+cuda110\n", + "jdcal==1.4.1\n", + "jedi==0.18.0\n", + "jieba==0.42.1\n", + "Jinja2==2.11.3\n", + "joblib==1.0.1\n", + "jpeg4py==0.1.4\n", + "jsonschema==2.6.0\n", + "jupyter==1.0.0\n", + "jupyter-client==5.3.5\n", + "jupyter-console==5.2.0\n", + "jupyter-core==4.7.1\n", + "jupyterlab-pygments==0.1.2\n", + "jupyterlab-widgets==1.0.0\n", + "kaggle==1.5.12\n", + "kapre==0.3.5\n", + "Keras==2.4.3\n", + "keras-nightly==2.5.0.dev2021032900\n", + "Keras-Preprocessing==1.1.2\n", + 
"keras-vis==0.4.1\n", + "kiwisolver==1.3.1\n", + "korean-lunar-calendar==0.2.1\n", + "librosa==0.8.0\n", + "lightgbm==2.2.3\n", + "llvmlite==0.34.0\n", + "lmdb==0.99\n", + "LunarCalendar==0.0.9\n", + "lxml==4.2.6\n", + "Markdown==3.3.4\n", + "MarkupSafe==2.0.1\n", + "matplotlib==3.2.2\n", + "matplotlib-inline==0.1.2\n", + "matplotlib-venn==0.11.6\n", + "missingno==0.4.2\n", + "mistune==0.8.4\n", + "mizani==0.6.0\n", + "mkl==2019.0\n", + "mlxtend==0.14.0\n", + "more-itertools==8.7.0\n", + "moviepy==0.2.3.5\n", + "mpmath==1.2.1\n", + "msgpack==1.0.2\n", + "multiprocess==0.70.11.1\n", + "multitasking==0.0.9\n", + "murmurhash==1.0.5\n", + "music21==5.5.0\n", + "natsort==5.5.0\n", + "nbclient==0.5.3\n", + "nbconvert==5.6.1\n", + "nbformat==5.1.3\n", + "nest-asyncio==1.5.1\n", + "netCDF4==1.5.6\n", + "networkx==2.5.1\n", + "nibabel==3.0.2\n", + "nltk==3.2.5\n", + "notebook==5.3.1\n", + "numba==0.51.2\n", + "numexpr==2.7.3\n", + "numpy==1.19.5\n", + "nvidia-ml-py3==7.352.0\n", + "oauth2client==4.1.3\n", + "oauthlib==3.1.0\n", + "okgrade==0.4.3\n", + "opencv-contrib-python==4.1.2.30\n", + "opencv-python==4.1.2.30\n", + "openpyxl==2.5.9\n", + "opt-einsum==3.3.0\n", + "osqp==0.6.2.post0\n", + "packaging==20.9\n", + "palettable==3.3.0\n", + "pandas==1.1.5\n", + "pandas-datareader==0.9.0\n", + "pandas-gbq==0.13.3\n", + "pandas-profiling==1.4.1\n", + "pandocfilters==1.4.3\n", + "panel==0.11.3\n", + "param==1.10.1\n", + "parso==0.8.2\n", + "pathlib==1.0.1\n", + "patsy==0.5.1\n", + "pexpect==4.8.0\n", + "pickleshare==0.7.5\n", + "Pillow==7.1.2\n", + "pip-tools==4.5.1\n", + "plac==1.1.3\n", + "plotly==4.4.1\n", + "plotnine==0.6.0\n", + "pluggy==0.7.1\n", + "pooch==1.3.0\n", + "portpicker==1.3.9\n", + "prefetch-generator==1.0.1\n", + "preshed==3.0.5\n", + "prettytable==2.1.0\n", + "progressbar2==3.38.0\n", + "prometheus-client==0.10.1\n", + "promise==2.3\n", + "prompt-toolkit==1.0.18\n", + "protobuf==3.12.4\n", + "psutil==5.4.8\n", + "psycopg2==2.7.6.1\n", + "ptyprocess==0.7.0\n", 
+ "py==1.10.0\n", + "pyarrow==3.0.0\n", + "pyasn1==0.4.8\n", + "pyasn1-modules==0.2.8\n", + "pycocotools==2.0.2\n", + "pycparser==2.20\n", + "pyct==0.4.8\n", + "pydata-google-auth==1.2.0\n", + "pydot==1.3.0\n", + "pydot-ng==2.0.0\n", + "pydotplus==2.0.2\n", + "PyDrive==1.3.1\n", + "pyemd==0.5.1\n", + "pyerfa==2.0.0\n", + "pyglet==1.5.0\n", + "Pygments==2.6.1\n", + "pygobject==3.26.1\n", + "pymc3==3.11.2\n", + "PyMeeus==0.5.11\n", + "pymongo==3.11.4\n", + "pymystem3==0.2.0\n", + "PyOpenGL==3.1.5\n", + "pyparsing==2.4.7\n", + "pyrsistent==0.17.3\n", + "pysndfile==1.3.8\n", + "PySocks==1.7.1\n", + "pystan==2.19.1.1\n", + "pytest==3.6.4\n", + "python-apt==0.0.0\n", + "python-chess==0.23.11\n", + "python-dateutil==2.8.1\n", + "python-louvain==0.15\n", + "python-slugify==5.0.2\n", + "python-utils==2.5.6\n", + "pytz==2018.9\n", + "PyVirtualDisplay==2.1\n", + "pyviz-comms==2.0.1\n", + "PyWavelets==1.1.1\n", + "PyYAML==3.13\n", + "pyzmq==22.0.3\n", + "qdldl==0.1.5.post0\n", + "qtconsole==5.1.0\n", + "QtPy==1.9.0\n", + "regex==2019.12.20\n", + "requests==2.23.0\n", + "requests-oauthlib==1.3.0\n", + "resampy==0.2.2\n", + "retrying==1.3.3\n", + "rpy2==3.4.4\n", + "rsa==4.7.2\n", + "scikit-image==0.16.2\n", + "scikit-learn==0.22.2.post1\n", + "scipy==1.4.1\n", + "screen-resolution-extra==0.0.0\n", + "scs==2.1.3\n", + "seaborn==0.11.1\n", + "semver==2.13.0\n", + "Send2Trash==1.5.0\n", + "setuptools-git==1.2\n", + "Shapely==1.7.1\n", + "simplegeneric==0.8.1\n", + "six==1.15.0\n", + "sklearn==0.0\n", + "sklearn-pandas==1.8.0\n", + "smart-open==5.0.0\n", + "snowballstemmer==2.1.0\n", + "sortedcontainers==2.4.0\n", + "SoundFile==0.10.3.post1\n", + "spacy==2.2.4\n", + "Sphinx==1.8.5\n", + "sphinxcontrib-serializinghtml==1.1.4\n", + "sphinxcontrib-websupport==1.2.4\n", + "SQLAlchemy==1.4.15\n", + "sqlparse==0.4.1\n", + "srsly==1.0.5\n", + "statsmodels==0.10.2\n", + "sympy==1.7.1\n", + "tables==3.4.4\n", + "tabulate==0.8.9\n", + "tblib==1.7.0\n", + "tensorboard==2.5.0\n", + 
"tensorboard-data-server==0.6.1\n", + "tensorboard-plugin-wit==1.8.0\n", + "tensorflow==2.5.0\n", + "tensorflow-datasets==4.0.1\n", + "tensorflow-estimator==2.5.0\n", + "tensorflow-gcs-config==2.5.0\n", + "tensorflow-hub==0.12.0\n", + "tensorflow-metadata==1.0.0\n", + "tensorflow-probability==0.12.1\n", + "termcolor==1.1.0\n", + "terminado==0.10.0\n", + "testpath==0.5.0\n", + "text-unidecode==1.3\n", + "textblob==0.15.3\n", + "Theano-PyMC==1.1.2\n", + "thinc==7.4.0\n", + "tifffile==2021.4.8\n", + "toml==0.10.2\n", + "toolz==0.11.1\n", + "torch==1.8.1+cu101\n", + "torchsummary==1.5.1\n", + "torchtext==0.9.1\n", + "torchvision==0.9.1+cu101\n", + "tornado==5.1.1\n", + "tqdm==4.41.1\n", + "traitlets==5.0.5\n", + "tweepy==3.10.0\n", + "typeguard==2.7.1\n", + "typing-extensions==3.7.4.3\n", + "tzlocal==1.5.1\n", + "uritemplate==3.0.1\n", + "urllib3==1.24.3\n", + "vega-datasets==0.9.0\n", + "wasabi==0.8.2\n", + "wcwidth==0.2.5\n", + "webencodings==0.5.1\n", + "Werkzeug==1.0.1\n", + "widgetsnbextension==3.5.1\n", + "wordcloud==1.5.0\n", + "wrapt==1.12.1\n", + "xarray==0.18.2\n", + "xgboost==0.90\n", + "xkit==0.0.0\n", + "xlrd==1.1.0\n", + "xlwt==1.3.0\n", + "yellowbrick==0.9.1\n", + "zict==2.0.0\n", + "zipp==3.4.1\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NrkVvTrvWZ5H" + }, + "source": [ + "## 什麼是 Lunar Lander?\n", + "\n", + "“LunarLander-v2” 這個環境是在模擬登月小艇降落在月球表面時的情形。\n", + "這個任務的目標是讓登月小艇「安全地」降落在兩個黃色旗幟間的平地上。\n", + "> Landing pad is always at coordinates (0,0).\n", + "> Coordinates are the first two numbers in state vector.\n", + "\n", + "![](https://gym.openai.com/assets/docs/aeloop-138c89d44114492fd02822303e6b4b07213010bb14ca5856d2d49d6b62d88e53.svg)\n", + "\n", + "所謂的「環境」其實同時包括了 agent 和 environment。\n", + "我們利用 `step()` 這個函式讓 agent 行動,而後函式便會回傳 environment 給予的 observation/state(以下這兩個名詞代表同樣的意思)和 reward。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bIbp82sljvAt" + }, + "source": [ + "### Observation / 
State\n", + "\n", + "首先,我們可以看看 environment 回傳給 agent 的 observation 究竟是長什麼樣子的資料:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rsXZra3N9R5T", + "outputId": "9512a449-f90a-4545-8aef-dd9aeb9b2b9e" + }, + "source": [ + "print(env.observation_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Box(-inf, inf, (8,), float32)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ezdfoThbAQ49" + }, + "source": [ + "`Box(-inf, inf, (8,), float32)` 說明我們會拿到 8 維的向量作為 observation,其中包含:水平及垂直座標、水平及垂直速度、角度、角速度,以及左右兩隻腳是否觸地,這部分我們就不細說。\n", + "\n", + "### Action\n", + "\n", + "而在 agent 得到 observation 和 reward 以後,能夠採取的動作有:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "p1k4dIrBAaKi", + "outputId": "64cd523a-bbff-4569-cae9-f65123b3c604" + }, + "source": [ + "print(env.action_space)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Discrete(4)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dejXT6PHBrPn" + }, + "source": [ + "`Discrete(4)` 說明 agent 可以採取四種離散的行動:\n", + "- 0 代表不採取任何行動\n", + "- 2 代表主引擎向下噴射\n", + "- 1, 3 則是向左右噴射\n", + "\n", + "接下來,我們嘗試讓 agent 與 environment 互動。\n", + "在進行任何操作前,建議先呼叫 `reset()` 函式讓整個「環境」重置。\n", + "而這個函式同時會回傳「環境」最初始的狀態。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pi4OmrmZgnWA", + "outputId": "c358ff73-1879-4a74-9579-9ee97740dc16" + }, + "source": [ + "initial_state = env.reset()\n", + "print(initial_state)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0.00396109 1.4083536 0.40119505 -0.11407257 -0.00458307 -0.09087662\n", + " 0. 0. 
]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uBx0mEqqgxJ9" + }, + "source": [ + "接著,我們試著從 agent 的行動空間(共四種動作)中,隨機採取一個行動。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxkOEXRKgizt", + "outputId": "8912cf80-2310-401b-a37e-c0ded59626ee" + }, + "source": [ + "random_action = env.action_space.sample()\n", + "print(random_action)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mns-bO01g0-J" + }, + "source": [ + "再利用 `step()` 函式讓 agent 根據我們隨機抽樣出來的 `random_action` 動作。\n", + "而這個函式會回傳四項資訊:\n", + "- observation / state\n", + "- reward\n", + "- 完成與否\n", + "- 其餘資訊" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "E_WViSxGgIk9" + }, + "source": [ + "observation, reward, done, info = env.step(random_action)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FdieGq7NuBIm" + }, + "source": [ + "第一項資訊 `observation` 即為 agent 採取行動之後,agent 對於環境的 observation 或者說環境的 state 為何。\n", + "而第三項資訊 `done` 則是 `True` 或 `False` 的布林值,當登月小艇成功著陸或是不幸墜毀時,代表這個回合(episode)也就跟著結束了,此時 `step()` 函式便會回傳 `done = True`,而在那之前,`done` 則保持 `False`。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yK7r126kuCNp", + "outputId": "3b99114f-e6b4-4a18-c80b-75189083bd55" + }, + "source": [ + "print(done)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "False\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GKdS8vOihxhc" + }, + "source": [ + "### Reward\n", + "\n", + "而「環境」給予的 reward 大致是這樣計算:\n", + "- 小艇墜毀得到 -100 分\n", + "- 小艇在黃旗幟之間成功著地則得 100~140 分\n", + "- 噴射主引擎(向下噴火)每個 frame -0.3 分\n", + "- 小艇最終完全靜止則再得 100 分\n", + "- 小艇每隻腳碰觸地面 +10 分\n", + 
"\n", + "> Reward for moving from the top of the screen to landing pad and zero speed is about 100..140 points.\n", + "> If lander moves away from landing pad it loses reward back.\n", + "> Episode finishes if the lander crashes or comes to rest, receiving additional -100 or +100 points.\n", + "> Each leg ground contact is +10.\n", + "> Firing main engine is -0.3 points each frame.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vxQNs77hi0_7", + "outputId": "dacd87b3-734e-44f3-c5b4-361b323def84" + }, + "source": [ + "print(reward) # after doing a random action (0), the immediate reward is stored in this " + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "-0.8588900517154912\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mhqp6D-XgHpe" + }, + "source": [ + "### Random Agent\n", + "\n", + "最後,在進入實做之前,我們就來看看這樣一個 random agent 能否成功登陸月球:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 269 + }, + "id": "Y3G0bxoccelv", + "outputId": "36096915-445e-40fb-b349-a6a9a5b900d5" + }, + "source": [ + "\n", + "env.reset()\n", + "\n", + "img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + "done = False\n", + "while not done:\n", + " action = env.action_space.sample()\n", + " observation, reward, done, _ = env.step(action)\n", + "\n", + " img.set_data(env.render(mode='rgb_array'))\n", + " display.display(plt.gcf())\n", + " display.clear_output(wait=True)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F5paWqo7tWL2" + }, + "source": [ + "## Policy Gradient\n", + "\n", + "現在來搭建一個簡單的 policy network。\n", + "我們預設模型的輸入是 8-dim 的 observation,輸出則是離散的四個動作之一:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "J8tdmeD-tZew" + }, + "source": [ + "class PolicyGradientNetwork(nn.Module):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(8, 16)\n", + " self.fc2 = nn.Linear(16, 16)\n", + " self.fc3 = nn.Linear(16, 4)\n", + "\n", + " def forward(self, state):\n", + " hid = torch.tanh(self.fc1(state))\n", + " hid = torch.tanh(self.fc2(hid))\n", + " return F.softmax(self.fc3(hid), dim=-1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ynbqJrhIFTC3" + }, + "source": [ + "再來,搭建一個簡單的 agent,並搭配上方的 policy network 來採取行動。\n", + "這個 agent 能做到以下幾件事:\n", + "- `learn()`:從記下來的 log probabilities 及 rewards 來更新 policy network。\n", + "- `sample()`:從 environment 得到 observation 之後,利用 policy network 得出應該採取的行動。\n", + "而此函式除了回傳抽樣出來的 action,也會回傳此次抽樣的 log probabilities。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zZo-IxJx286z" + }, + "source": [ + "\n", + "class PolicyGradientAgent():\n", + " \n", + " def __init__(self, network):\n", + " self.network = network\n", + " self.optimizer = optim.SGD(self.network.parameters(), lr=0.001)\n", + " \n", + " def forward(self, state):\n", + " return self.network(state)\n", + " def learn(self, log_probs, rewards):\n", + " loss = (-log_probs * rewards).sum() # You don't need to revise this to pass simple baseline (but you can)\n", + "\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " \n", + " def sample(self, state):\n", + " action_prob = self.network(torch.FloatTensor(state))\n", + " action_dist = Categorical(action_prob)\n", + " action = 
action_dist.sample()\n", + " log_prob = action_dist.log_prob(action)\n", + " return action.item(), log_prob\n", + "\n", + " def save(self, PATH): # You should not revise this\n", + " Agent_Dict = {\n", + " \"network\" : self.network.state_dict(),\n", + " \"optimizer\" : self.optimizer.state_dict()\n", + " }\n", + " torch.save(Agent_Dict, PATH)\n", + "\n", + " def load(self, PATH): # You should not revise this\n", + " checkpoint = torch.load(PATH)\n", + " self.network.load_state_dict(checkpoint[\"network\"])\n", + " #如果要儲存過程或是中斷訓練後想繼續可以用喔 ^_^\n", + " self.optimizer.load_state_dict(checkpoint[\"optimizer\"])\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehPlnTKyRZf9" + }, + "source": [ + "最後,建立一個 network 和 agent,就可以開始進行訓練了。" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GfJIvML-RYjL" + }, + "source": [ + "network = PolicyGradientNetwork()\n", + "agent = PolicyGradientAgent(network)\n", + "#agent = PolicyGradientAgent()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ouv23glgf5Qt" + }, + "source": [ + "## 訓練 Agent\n", + "\n", + "現在我們開始訓練 agent。\n", + "透過讓 agent 和 environment 互動,我們記住每一組對應的 log probabilities 及 reward;每個回合(episode)在成功登陸或者不幸墜毀時結束,而每蒐集完 `EPISODE_PER_BATCH` 個回合後,才回放這些「記憶」來訓練 policy network。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "2acab9542fe64b979fa2ac2adb3f10a8", + "f288c64b5ff748eb82178bf1de17934f", + "de34e5b178f5470e98e0275102a65042", + "c93cba301cac439ca56fb6b45bd1c4e4", + "43c6ee720b674626ab3a869bda5dd6e3", + "2465d2b109d34922a486341232d86ad6", + "aa27187195be4da9874025395eac35eb", + "02d196d4f9734f998455d92bd9300adb" + ] + }, + "id": "vg5rxBBaf38_", + "outputId": "eae0c9f4-0efc-40fe-a29e-7f7194613f6d" + }, + "source": [ + "agent.network.train() # 訓練前,先確保 network 處在 training 模式\n", + "EPISODE_PER_BATCH = 5 # 每蒐集 5 個 episodes 更新一次 agent\n", + 
"NUM_BATCH = 400 # 總共更新 400 次\n", + "\n", + "avg_total_rewards, avg_final_rewards = [], []\n", + "\n", + "prg_bar = tqdm(range(NUM_BATCH))\n", + "for batch in prg_bar:\n", + "\n", + " log_probs, rewards = [], []\n", + " total_rewards, final_rewards = [], []\n", + "\n", + " # 蒐集訓練資料\n", + " for episode in range(EPISODE_PER_BATCH):\n", + " \n", + " state = env.reset()\n", + " total_reward, total_step = 0, 0\n", + " seq_rewards = []\n", + " while True:\n", + "\n", + " action, log_prob = agent.sample(state) # at , log(at|st)\n", + " next_state, reward, done, _ = env.step(action)\n", + "\n", + " log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]\n", + " # seq_rewards.append(reward)\n", + " state = next_state\n", + " total_reward += reward\n", + " total_step += 1\n", + " rewards.append(reward) #改這裡\n", + " # ! 重要 !\n", + " # 現在的reward 的implementation 為每個時刻的瞬時reward, 給定action_list : a1, a2, a3 ......\n", + " # reward : r1, r2 ,r3 ......\n", + " # medium:將reward調整成accumulative decaying reward, 給定action_list : a1, a2, a3 ......\n", + " # reward : r1+0.99*r2+0.99^2*r3+......, r2+0.99*r3+0.99^2*r4+...... 
,r3+0.99*r4+0.99^2*r5+ ......\n", + " # boss : implement DQN\n", + " if done:\n", + " final_rewards.append(reward)\n", + " total_rewards.append(total_reward)\n", + " break\n", + "\n", + " print(f\"rewards looks like \", np.shape(rewards)) \n", + " print(f\"log_probs looks like \", np.shape(log_probs)) \n", + " # 紀錄訓練過程\n", + " avg_total_reward = sum(total_rewards) / len(total_rewards)\n", + " avg_final_reward = sum(final_rewards) / len(final_rewards)\n", + " avg_total_rewards.append(avg_total_reward)\n", + " avg_final_rewards.append(avg_final_reward)\n", + " prg_bar.set_description(f\"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}\")\n", + "\n", + " # 更新網路\n", + " # rewards = np.concatenate(rewards, axis=0)\n", + " rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9) # 將 reward 正規標準化\n", + " agent.learn(torch.stack(log_probs), torch.from_numpy(rewards))\n", + " print(\"logs prob looks like \", torch.stack(log_probs).size())\n", + " print(\"torch.from_numpy(rewards) looks like \", torch.from_numpy(rewards).size())" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2acab9542fe64b979fa2ac2adb3f10a8", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "rewards looks like (448,)\n", + "log_probs looks like (448,)\n", + "logs prob looks like torch.Size([448])\n", + "torch.from_numpy(rewards) looks like torch.Size([448])\n", + "rewards looks like (515,)\n", + "log_probs looks like (515,)\n", + "logs prob looks like torch.Size([515])\n", + "torch.from_numpy(rewards) looks like torch.Size([515])\n", + "rewards looks like (392,)\n", + "log_probs looks like (392,)\n", + "logs prob looks like torch.Size([392])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([392])\n", + "rewards looks like (518,)\n", + "log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (472,)\n", + "log_probs looks like (472,)\n", + "logs prob looks like torch.Size([472])\n", + "torch.from_numpy(rewards) looks like torch.Size([472])\n", + "rewards looks like (530,)\n", + "log_probs looks like (530,)\n", + "logs prob looks like torch.Size([530])\n", + "torch.from_numpy(rewards) looks like torch.Size([530])\n", + "rewards looks like (463,)\n", + "log_probs looks like (463,)\n", + "logs prob looks like torch.Size([463])\n", + "torch.from_numpy(rewards) looks like torch.Size([463])\n", + "rewards looks like (540,)\n", + "log_probs looks like (540,)\n", + "logs prob looks like torch.Size([540])\n", + "torch.from_numpy(rewards) looks like torch.Size([540])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (449,)\n", + "log_probs looks like (449,)\n", + "logs prob looks like torch.Size([449])\n", + "torch.from_numpy(rewards) looks like torch.Size([449])\n", + "rewards looks like (602,)\n", + "log_probs looks like (602,)\n", + "logs prob looks like torch.Size([602])\n", + "torch.from_numpy(rewards) looks like torch.Size([602])\n", + "rewards looks like (542,)\n", + "log_probs looks like (542,)\n", + "logs prob looks like torch.Size([542])\n", + "torch.from_numpy(rewards) looks like torch.Size([542])\n", + "rewards looks like (503,)\n", + "log_probs looks like (503,)\n", + "logs prob looks like torch.Size([503])\n", + "torch.from_numpy(rewards) looks like torch.Size([503])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (518,)\n", + 
"log_probs looks like (518,)\n", + "logs prob looks like torch.Size([518])\n", + "torch.from_numpy(rewards) looks like torch.Size([518])\n", + "rewards looks like (421,)\n", + "log_probs looks like (421,)\n", + "logs prob looks like torch.Size([421])\n", + "torch.from_numpy(rewards) looks like torch.Size([421])\n", + "rewards looks like (592,)\n", + "log_probs looks like (592,)\n", + "logs prob looks like torch.Size([592])\n", + "torch.from_numpy(rewards) looks like torch.Size([592])\n", + "rewards looks like (520,)\n", + "log_probs looks like (520,)\n", + "logs prob looks like torch.Size([520])\n", + "torch.from_numpy(rewards) looks like torch.Size([520])\n", + "rewards looks like (494,)\n", + "log_probs looks like (494,)\n", + "logs prob looks like torch.Size([494])\n", + "torch.from_numpy(rewards) looks like torch.Size([494])\n", + "rewards looks like (461,)\n", + "log_probs looks like (461,)\n", + "logs prob looks like torch.Size([461])\n", + "torch.from_numpy(rewards) looks like torch.Size([461])\n", + "rewards looks like (572,)\n", + "log_probs looks like (572,)\n", + "logs prob looks like torch.Size([572])\n", + "torch.from_numpy(rewards) looks like torch.Size([572])\n", + "rewards looks like (593,)\n", + "log_probs looks like (593,)\n", + "logs prob looks like torch.Size([593])\n", + "torch.from_numpy(rewards) looks like torch.Size([593])\n", + "rewards looks like (569,)\n", + "log_probs looks like (569,)\n", + "logs prob looks like torch.Size([569])\n", + "torch.from_numpy(rewards) looks like torch.Size([569])\n", + "rewards looks like (546,)\n", + "log_probs looks like (546,)\n", + "logs prob looks like torch.Size([546])\n", + "torch.from_numpy(rewards) looks like torch.Size([546])\n", + "rewards looks like (612,)\n", + "log_probs looks like (612,)\n", + "logs prob looks like torch.Size([612])\n", + "torch.from_numpy(rewards) looks like torch.Size([612])\n", + "rewards looks like (534,)\n", + "log_probs looks like (534,)\n", + "logs prob looks like 
torch.Size([534])\n", + "torch.from_numpy(rewards) looks like torch.Size([534])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (535,)\n", + "log_probs looks like (535,)\n", + "logs prob looks like torch.Size([535])\n", + "torch.from_numpy(rewards) looks like torch.Size([535])\n", + "rewards looks like (533,)\n", + "log_probs looks like (533,)\n", + "logs prob looks like torch.Size([533])\n", + "torch.from_numpy(rewards) looks like torch.Size([533])\n", + "rewards looks like (521,)\n", + "log_probs looks like (521,)\n", + "logs prob looks like torch.Size([521])\n", + "torch.from_numpy(rewards) looks like torch.Size([521])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (575,)\n", + "log_probs looks like (575,)\n", + "logs prob looks like torch.Size([575])\n", + "torch.from_numpy(rewards) looks like torch.Size([575])\n", + "rewards looks like (709,)\n", + "log_probs looks like (709,)\n", + "logs prob looks like torch.Size([709])\n", + "torch.from_numpy(rewards) looks like torch.Size([709])\n", + "rewards looks like (486,)\n", + "log_probs looks like (486,)\n", + "logs prob looks like torch.Size([486])\n", + "torch.from_numpy(rewards) looks like torch.Size([486])\n", + "rewards looks like (557,)\n", + "log_probs looks like (557,)\n", + "logs prob looks like torch.Size([557])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([557])\n", + "rewards looks like (517,)\n", + "log_probs looks like (517,)\n", + "logs prob looks like torch.Size([517])\n", + "torch.from_numpy(rewards) looks like torch.Size([517])\n", + "rewards looks like (550,)\n", + "log_probs looks like (550,)\n", + "logs prob looks like torch.Size([550])\n", + "torch.from_numpy(rewards) looks like torch.Size([550])\n", + "rewards looks like (690,)\n", + "log_probs looks like (690,)\n", + "logs prob looks like torch.Size([690])\n", + "torch.from_numpy(rewards) looks like torch.Size([690])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (689,)\n", + "log_probs looks like (689,)\n", + "logs prob looks like torch.Size([689])\n", + "torch.from_numpy(rewards) looks like torch.Size([689])\n", + "rewards looks like (1059,)\n", + "log_probs looks like (1059,)\n", + "logs prob looks like torch.Size([1059])\n", + "torch.from_numpy(rewards) looks like torch.Size([1059])\n", + "rewards looks like (619,)\n", + "log_probs looks like (619,)\n", + "logs prob looks like torch.Size([619])\n", + "torch.from_numpy(rewards) looks like torch.Size([619])\n", + "rewards looks like (527,)\n", + "log_probs looks like (527,)\n", + "logs prob looks like torch.Size([527])\n", + "torch.from_numpy(rewards) looks like torch.Size([527])\n", + "rewards looks like (514,)\n", + "log_probs looks like (514,)\n", + "logs prob looks like torch.Size([514])\n", + "torch.from_numpy(rewards) looks like torch.Size([514])\n", + "rewards looks like (655,)\n", + "log_probs looks like (655,)\n", + "logs prob looks like torch.Size([655])\n", + "torch.from_numpy(rewards) looks like torch.Size([655])\n", + "rewards looks like (667,)\n", + "log_probs looks like (667,)\n", + "logs prob looks like torch.Size([667])\n", + "torch.from_numpy(rewards) looks like torch.Size([667])\n", + "rewards looks like (712,)\n", + 
"log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (636,)\n", + "log_probs looks like (636,)\n", + "logs prob looks like torch.Size([636])\n", + "torch.from_numpy(rewards) looks like torch.Size([636])\n", + "rewards looks like (620,)\n", + "log_probs looks like (620,)\n", + "logs prob looks like torch.Size([620])\n", + "torch.from_numpy(rewards) looks like torch.Size([620])\n", + "rewards looks like (543,)\n", + "log_probs looks like (543,)\n", + "logs prob looks like torch.Size([543])\n", + "torch.from_numpy(rewards) looks like torch.Size([543])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (498,)\n", + "log_probs looks like (498,)\n", + "logs prob looks like torch.Size([498])\n", + "torch.from_numpy(rewards) looks like torch.Size([498])\n", + "rewards looks like (586,)\n", + "log_probs looks like (586,)\n", + "logs prob looks like torch.Size([586])\n", + "torch.from_numpy(rewards) looks like torch.Size([586])\n", + "rewards looks like (591,)\n", + "log_probs looks like (591,)\n", + "logs prob looks like torch.Size([591])\n", + "torch.from_numpy(rewards) looks like torch.Size([591])\n", + "rewards looks like (693,)\n", + "log_probs looks like (693,)\n", + "logs prob looks like torch.Size([693])\n", + "torch.from_numpy(rewards) looks like torch.Size([693])\n", + "rewards looks like (648,)\n", + "log_probs looks like (648,)\n", + "logs prob looks like torch.Size([648])\n", + "torch.from_numpy(rewards) looks like torch.Size([648])\n", + "rewards looks like (513,)\n", + "log_probs looks like (513,)\n", + "logs prob looks like torch.Size([513])\n", + "torch.from_numpy(rewards) looks like torch.Size([513])\n", + "rewards looks like (574,)\n", + "log_probs looks like (574,)\n", + "logs prob looks like 
torch.Size([574])\n", + "torch.from_numpy(rewards) looks like torch.Size([574])\n", + "rewards looks like (718,)\n", + "log_probs looks like (718,)\n", + "logs prob looks like torch.Size([718])\n", + "torch.from_numpy(rewards) looks like torch.Size([718])\n", + "rewards looks like (730,)\n", + "log_probs looks like (730,)\n", + "logs prob looks like torch.Size([730])\n", + "torch.from_numpy(rewards) looks like torch.Size([730])\n", + "rewards looks like (668,)\n", + "log_probs looks like (668,)\n", + "logs prob looks like torch.Size([668])\n", + "torch.from_numpy(rewards) looks like torch.Size([668])\n", + "rewards looks like (754,)\n", + "log_probs looks like (754,)\n", + "logs prob looks like torch.Size([754])\n", + "torch.from_numpy(rewards) looks like torch.Size([754])\n", + "rewards looks like (712,)\n", + "log_probs looks like (712,)\n", + "logs prob looks like torch.Size([712])\n", + "torch.from_numpy(rewards) looks like torch.Size([712])\n", + "rewards looks like (470,)\n", + "log_probs looks like (470,)\n", + "logs prob looks like torch.Size([470])\n", + "torch.from_numpy(rewards) looks like torch.Size([470])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (585,)\n", + "log_probs looks like (585,)\n", + "logs prob looks like torch.Size([585])\n", + "torch.from_numpy(rewards) looks like torch.Size([585])\n", + "rewards looks like (512,)\n", + "log_probs looks like (512,)\n", + "logs prob looks like torch.Size([512])\n", + "torch.from_numpy(rewards) looks like torch.Size([512])\n", + "rewards looks like (702,)\n", + "log_probs looks like (702,)\n", + "logs prob looks like torch.Size([702])\n", + "torch.from_numpy(rewards) looks like torch.Size([702])\n", + "rewards looks like (596,)\n", + "log_probs looks like (596,)\n", + "logs prob looks like torch.Size([596])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([596])\n", + "rewards looks like (626,)\n", + "log_probs looks like (626,)\n", + "logs prob looks like torch.Size([626])\n", + "torch.from_numpy(rewards) looks like torch.Size([626])\n", + "rewards looks like (566,)\n", + "log_probs looks like (566,)\n", + "logs prob looks like torch.Size([566])\n", + "torch.from_numpy(rewards) looks like torch.Size([566])\n", + "rewards looks like (717,)\n", + "log_probs looks like (717,)\n", + "logs prob looks like torch.Size([717])\n", + "torch.from_numpy(rewards) looks like torch.Size([717])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (565,)\n", + "log_probs looks like (565,)\n", + "logs prob looks like torch.Size([565])\n", + "torch.from_numpy(rewards) looks like torch.Size([565])\n", + "rewards looks like (450,)\n", + "log_probs looks like (450,)\n", + "logs prob looks like torch.Size([450])\n", + "torch.from_numpy(rewards) looks like torch.Size([450])\n", + "rewards looks like (584,)\n", + "log_probs looks like (584,)\n", + "logs prob looks like torch.Size([584])\n", + "torch.from_numpy(rewards) looks like torch.Size([584])\n", + "rewards looks like (670,)\n", + "log_probs looks like (670,)\n", + "logs prob looks like torch.Size([670])\n", + "torch.from_numpy(rewards) looks like torch.Size([670])\n", + "rewards looks like (691,)\n", + "log_probs looks like (691,)\n", + "logs prob looks like torch.Size([691])\n", + "torch.from_numpy(rewards) looks like torch.Size([691])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (752,)\n", + "log_probs looks like (752,)\n", + "logs prob looks like torch.Size([752])\n", + "torch.from_numpy(rewards) looks like torch.Size([752])\n", + "rewards looks like (478,)\n", + 
"log_probs looks like (478,)\n", + "logs prob looks like torch.Size([478])\n", + "torch.from_numpy(rewards) looks like torch.Size([478])\n", + "rewards looks like (553,)\n", + "log_probs looks like (553,)\n", + "logs prob looks like torch.Size([553])\n", + "torch.from_numpy(rewards) looks like torch.Size([553])\n", + "rewards looks like (1660,)\n", + "log_probs looks like (1660,)\n", + "logs prob looks like torch.Size([1660])\n", + "torch.from_numpy(rewards) looks like torch.Size([1660])\n", + "rewards looks like (751,)\n", + "log_probs looks like (751,)\n", + "logs prob looks like torch.Size([751])\n", + "torch.from_numpy(rewards) looks like torch.Size([751])\n", + "rewards looks like (801,)\n", + "log_probs looks like (801,)\n", + "logs prob looks like torch.Size([801])\n", + "torch.from_numpy(rewards) looks like torch.Size([801])\n", + "rewards looks like (715,)\n", + "log_probs looks like (715,)\n", + "logs prob looks like torch.Size([715])\n", + "torch.from_numpy(rewards) looks like torch.Size([715])\n", + "rewards looks like (708,)\n", + "log_probs looks like (708,)\n", + "logs prob looks like torch.Size([708])\n", + "torch.from_numpy(rewards) looks like torch.Size([708])\n", + "rewards looks like (609,)\n", + "log_probs looks like (609,)\n", + "logs prob looks like torch.Size([609])\n", + "torch.from_numpy(rewards) looks like torch.Size([609])\n", + "rewards looks like (732,)\n", + "log_probs looks like (732,)\n", + "logs prob looks like torch.Size([732])\n", + "torch.from_numpy(rewards) looks like torch.Size([732])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (603,)\n", + "log_probs looks like (603,)\n", + "logs prob looks like torch.Size([603])\n", + "torch.from_numpy(rewards) looks like torch.Size([603])\n", + "rewards looks like (665,)\n", + "log_probs looks like (665,)\n", + "logs prob looks like 
torch.Size([665])\n", + "torch.from_numpy(rewards) looks like torch.Size([665])\n", + "rewards looks like (658,)\n", + "log_probs looks like (658,)\n", + "logs prob looks like torch.Size([658])\n", + "torch.from_numpy(rewards) looks like torch.Size([658])\n", + "rewards looks like (783,)\n", + "log_probs looks like (783,)\n", + "logs prob looks like torch.Size([783])\n", + "torch.from_numpy(rewards) looks like torch.Size([783])\n", + "rewards looks like (652,)\n", + "log_probs looks like (652,)\n", + "logs prob looks like torch.Size([652])\n", + "torch.from_numpy(rewards) looks like torch.Size([652])\n", + "rewards looks like (892,)\n", + "log_probs looks like (892,)\n", + "logs prob looks like torch.Size([892])\n", + "torch.from_numpy(rewards) looks like torch.Size([892])\n", + "rewards looks like (821,)\n", + "log_probs looks like (821,)\n", + "logs prob looks like torch.Size([821])\n", + "torch.from_numpy(rewards) looks like torch.Size([821])\n", + "rewards looks like (986,)\n", + "log_probs looks like (986,)\n", + "logs prob looks like torch.Size([986])\n", + "torch.from_numpy(rewards) looks like torch.Size([986])\n", + "rewards looks like (916,)\n", + "log_probs looks like (916,)\n", + "logs prob looks like torch.Size([916])\n", + "torch.from_numpy(rewards) looks like torch.Size([916])\n", + "rewards looks like (742,)\n", + "log_probs looks like (742,)\n", + "logs prob looks like torch.Size([742])\n", + "torch.from_numpy(rewards) looks like torch.Size([742])\n", + "rewards looks like (604,)\n", + "log_probs looks like (604,)\n", + "logs prob looks like torch.Size([604])\n", + "torch.from_numpy(rewards) looks like torch.Size([604])\n", + "rewards looks like (818,)\n", + "log_probs looks like (818,)\n", + "logs prob looks like torch.Size([818])\n", + "torch.from_numpy(rewards) looks like torch.Size([818])\n", + "rewards looks like (855,)\n", + "log_probs looks like (855,)\n", + "logs prob looks like torch.Size([855])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([855])\n", + "rewards looks like (795,)\n", + "log_probs looks like (795,)\n", + "logs prob looks like torch.Size([795])\n", + "torch.from_numpy(rewards) looks like torch.Size([795])\n", + "rewards looks like (868,)\n", + "log_probs looks like (868,)\n", + "logs prob looks like torch.Size([868])\n", + "torch.from_numpy(rewards) looks like torch.Size([868])\n", + "rewards looks like (800,)\n", + "log_probs looks like (800,)\n", + "logs prob looks like torch.Size([800])\n", + "torch.from_numpy(rewards) looks like torch.Size([800])\n", + "rewards looks like (820,)\n", + "log_probs looks like (820,)\n", + "logs prob looks like torch.Size([820])\n", + "torch.from_numpy(rewards) looks like torch.Size([820])\n", + "rewards looks like (760,)\n", + "log_probs looks like (760,)\n", + "logs prob looks like torch.Size([760])\n", + "torch.from_numpy(rewards) looks like torch.Size([760])\n", + "rewards looks like (886,)\n", + "log_probs looks like (886,)\n", + "logs prob looks like torch.Size([886])\n", + "torch.from_numpy(rewards) looks like torch.Size([886])\n", + "rewards looks like (1027,)\n", + "log_probs looks like (1027,)\n", + "logs prob looks like torch.Size([1027])\n", + "torch.from_numpy(rewards) looks like torch.Size([1027])\n", + "rewards looks like (819,)\n", + "log_probs looks like (819,)\n", + "logs prob looks like torch.Size([819])\n", + "torch.from_numpy(rewards) looks like torch.Size([819])\n", + "rewards looks like (934,)\n", + "log_probs looks like (934,)\n", + "logs prob looks like torch.Size([934])\n", + "torch.from_numpy(rewards) looks like torch.Size([934])\n", + "rewards looks like (1648,)\n", + "log_probs looks like (1648,)\n", + "logs prob looks like torch.Size([1648])\n", + "torch.from_numpy(rewards) looks like torch.Size([1648])\n", + "rewards looks like (1057,)\n", + "log_probs looks like (1057,)\n", + "logs prob looks like torch.Size([1057])\n", + "torch.from_numpy(rewards) looks like torch.Size([1057])\n", + "rewards looks like 
(861,)\n", + "log_probs looks like (861,)\n", + "logs prob looks like torch.Size([861])\n", + "torch.from_numpy(rewards) looks like torch.Size([861])\n", + "rewards looks like (1533,)\n", + "log_probs looks like (1533,)\n", + "logs prob looks like torch.Size([1533])\n", + "torch.from_numpy(rewards) looks like torch.Size([1533])\n", + "rewards looks like (920,)\n", + "log_probs looks like (920,)\n", + "logs prob looks like torch.Size([920])\n", + "torch.from_numpy(rewards) looks like torch.Size([920])\n", + "rewards looks like (905,)\n", + "log_probs looks like (905,)\n", + "logs prob looks like torch.Size([905])\n", + "torch.from_numpy(rewards) looks like torch.Size([905])\n", + "rewards looks like (814,)\n", + "log_probs looks like (814,)\n", + "logs prob looks like torch.Size([814])\n", + "torch.from_numpy(rewards) looks like torch.Size([814])\n", + "rewards looks like (809,)\n", + "log_probs looks like (809,)\n", + "logs prob looks like torch.Size([809])\n", + "torch.from_numpy(rewards) looks like torch.Size([809])\n", + "rewards looks like (873,)\n", + "log_probs looks like (873,)\n", + "logs prob looks like torch.Size([873])\n", + "torch.from_numpy(rewards) looks like torch.Size([873])\n", + "rewards looks like (727,)\n", + "log_probs looks like (727,)\n", + "logs prob looks like torch.Size([727])\n", + "torch.from_numpy(rewards) looks like torch.Size([727])\n", + "rewards looks like (1129,)\n", + "log_probs looks like (1129,)\n", + "logs prob looks like torch.Size([1129])\n", + "torch.from_numpy(rewards) looks like torch.Size([1129])\n", + "rewards looks like (1394,)\n", + "log_probs looks like (1394,)\n", + "logs prob looks like torch.Size([1394])\n", + "torch.from_numpy(rewards) looks like torch.Size([1394])\n", + "rewards looks like (884,)\n", + "log_probs looks like (884,)\n", + "logs prob looks like torch.Size([884])\n", + "torch.from_numpy(rewards) looks like torch.Size([884])\n", + "rewards looks like (1132,)\n", + "log_probs looks like (1132,)\n", + 
"logs prob looks like torch.Size([1132])\n", + "torch.from_numpy(rewards) looks like torch.Size([1132])\n", + "rewards looks like (1007,)\n", + "log_probs looks like (1007,)\n", + "logs prob looks like torch.Size([1007])\n", + "torch.from_numpy(rewards) looks like torch.Size([1007])\n", + "rewards looks like (711,)\n", + "log_probs looks like (711,)\n", + "logs prob looks like torch.Size([711])\n", + "torch.from_numpy(rewards) looks like torch.Size([711])\n", + "rewards looks like (836,)\n", + "log_probs looks like (836,)\n", + "logs prob looks like torch.Size([836])\n", + "torch.from_numpy(rewards) looks like torch.Size([836])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (896,)\n", + "log_probs looks like (896,)\n", + "logs prob looks like torch.Size([896])\n", + "torch.from_numpy(rewards) looks like torch.Size([896])\n", + "rewards looks like (912,)\n", + "log_probs looks like (912,)\n", + "logs prob looks like torch.Size([912])\n", + "torch.from_numpy(rewards) looks like torch.Size([912])\n", + "rewards looks like (1478,)\n", + "log_probs looks like (1478,)\n", + "logs prob looks like torch.Size([1478])\n", + "torch.from_numpy(rewards) looks like torch.Size([1478])\n", + "rewards looks like (1279,)\n", + "log_probs looks like (1279,)\n", + "logs prob looks like torch.Size([1279])\n", + "torch.from_numpy(rewards) looks like torch.Size([1279])\n", + "rewards looks like (676,)\n", + "log_probs looks like (676,)\n", + "logs prob looks like torch.Size([676])\n", + "torch.from_numpy(rewards) looks like torch.Size([676])\n", + "rewards looks like (1768,)\n", + "log_probs looks like (1768,)\n", + "logs prob looks like torch.Size([1768])\n", + "torch.from_numpy(rewards) looks like torch.Size([1768])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like 
torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (1252,)\n", + "log_probs looks like (1252,)\n", + "logs prob looks like torch.Size([1252])\n", + "torch.from_numpy(rewards) looks like torch.Size([1252])\n", + "rewards looks like (995,)\n", + "log_probs looks like (995,)\n", + "logs prob looks like torch.Size([995])\n", + "torch.from_numpy(rewards) looks like torch.Size([995])\n", + "rewards looks like (1075,)\n", + "log_probs looks like (1075,)\n", + "logs prob looks like torch.Size([1075])\n", + "torch.from_numpy(rewards) looks like torch.Size([1075])\n", + "rewards looks like (878,)\n", + "log_probs looks like (878,)\n", + "logs prob looks like torch.Size([878])\n", + "torch.from_numpy(rewards) looks like torch.Size([878])\n", + "rewards looks like (1341,)\n", + "log_probs looks like (1341,)\n", + "logs prob looks like torch.Size([1341])\n", + "torch.from_numpy(rewards) looks like torch.Size([1341])\n", + "rewards looks like (1518,)\n", + "log_probs looks like (1518,)\n", + "logs prob looks like torch.Size([1518])\n", + "torch.from_numpy(rewards) looks like torch.Size([1518])\n", + "rewards looks like (1781,)\n", + "log_probs looks like (1781,)\n", + "logs prob looks like torch.Size([1781])\n", + "torch.from_numpy(rewards) looks like torch.Size([1781])\n", + "rewards looks like (1725,)\n", + "log_probs looks like (1725,)\n", + "logs prob looks like torch.Size([1725])\n", + "torch.from_numpy(rewards) looks like torch.Size([1725])\n", + "rewards looks like (1602,)\n", + "log_probs looks like (1602,)\n", + "logs prob looks like torch.Size([1602])\n", + "torch.from_numpy(rewards) looks like torch.Size([1602])\n", + "rewards looks like (846,)\n", + "log_probs looks like (846,)\n", + "logs prob looks like torch.Size([846])\n", + "torch.from_numpy(rewards) looks like torch.Size([846])\n", + "rewards looks like (1211,)\n", + "log_probs looks like (1211,)\n", + "logs prob looks like torch.Size([1211])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1211])\n", + "rewards looks like (3273,)\n", + "log_probs looks like (3273,)\n", + "logs prob looks like torch.Size([3273])\n", + "torch.from_numpy(rewards) looks like torch.Size([3273])\n", + "rewards looks like (744,)\n", + "log_probs looks like (744,)\n", + "logs prob looks like torch.Size([744])\n", + "torch.from_numpy(rewards) looks like torch.Size([744])\n", + "rewards looks like (1751,)\n", + "log_probs looks like (1751,)\n", + "logs prob looks like torch.Size([1751])\n", + "torch.from_numpy(rewards) looks like torch.Size([1751])\n", + "rewards looks like (1244,)\n", + "log_probs looks like (1244,)\n", + "logs prob looks like torch.Size([1244])\n", + "torch.from_numpy(rewards) looks like torch.Size([1244])\n", + "rewards looks like (1313,)\n", + "log_probs looks like (1313,)\n", + "logs prob looks like torch.Size([1313])\n", + "torch.from_numpy(rewards) looks like torch.Size([1313])\n", + "rewards looks like (1993,)\n", + "log_probs looks like (1993,)\n", + "logs prob looks like torch.Size([1993])\n", + "torch.from_numpy(rewards) looks like torch.Size([1993])\n", + "rewards looks like (709,)\n", + "log_probs looks like (709,)\n", + "logs prob looks like torch.Size([709])\n", + "torch.from_numpy(rewards) looks like torch.Size([709])\n", + "rewards looks like (934,)\n", + "log_probs looks like (934,)\n", + "logs prob looks like torch.Size([934])\n", + "torch.from_numpy(rewards) looks like torch.Size([934])\n", + "rewards looks like (1386,)\n", + "log_probs looks like (1386,)\n", + "logs prob looks like torch.Size([1386])\n", + "torch.from_numpy(rewards) looks like torch.Size([1386])\n", + "rewards looks like (635,)\n", + "log_probs looks like (635,)\n", + "logs prob looks like torch.Size([635])\n", + "torch.from_numpy(rewards) looks like torch.Size([635])\n", + "rewards looks like (750,)\n", + "log_probs looks like (750,)\n", + "logs prob looks like torch.Size([750])\n", + "torch.from_numpy(rewards) looks like 
torch.Size([750])\n", + "rewards looks like (1832,)\n", + "log_probs looks like (1832,)\n", + "logs prob looks like torch.Size([1832])\n", + "torch.from_numpy(rewards) looks like torch.Size([1832])\n", + "rewards looks like (1237,)\n", + "log_probs looks like (1237,)\n", + "logs prob looks like torch.Size([1237])\n", + "torch.from_numpy(rewards) looks like torch.Size([1237])\n", + "rewards looks like (1605,)\n", + "log_probs looks like (1605,)\n", + "logs prob looks like torch.Size([1605])\n", + "torch.from_numpy(rewards) looks like torch.Size([1605])\n", + "rewards looks like (718,)\n", + "log_probs looks like (718,)\n", + "logs prob looks like torch.Size([718])\n", + "torch.from_numpy(rewards) looks like torch.Size([718])\n", + "rewards looks like (966,)\n", + "log_probs looks like (966,)\n", + "logs prob looks like torch.Size([966])\n", + "torch.from_numpy(rewards) looks like torch.Size([966])\n", + "rewards looks like (2696,)\n", + "log_probs looks like (2696,)\n", + "logs prob looks like torch.Size([2696])\n", + "torch.from_numpy(rewards) looks like torch.Size([2696])\n", + "rewards looks like (762,)\n", + "log_probs looks like (762,)\n", + "logs prob looks like torch.Size([762])\n", + "torch.from_numpy(rewards) looks like torch.Size([762])\n", + "rewards looks like (1048,)\n", + "log_probs looks like (1048,)\n", + "logs prob looks like torch.Size([1048])\n", + "torch.from_numpy(rewards) looks like torch.Size([1048])\n", + "rewards looks like (1573,)\n", + "log_probs looks like (1573,)\n", + "logs prob looks like torch.Size([1573])\n", + "torch.from_numpy(rewards) looks like torch.Size([1573])\n", + "rewards looks like (2192,)\n", + "log_probs looks like (2192,)\n", + "logs prob looks like torch.Size([2192])\n", + "torch.from_numpy(rewards) looks like torch.Size([2192])\n", + "rewards looks like (599,)\n", + "log_probs looks like (599,)\n", + "logs prob looks like torch.Size([599])\n", + "torch.from_numpy(rewards) looks like torch.Size([599])\n", + "rewards 
looks like (758,)\n", + "log_probs looks like (758,)\n", + "logs prob looks like torch.Size([758])\n", + "torch.from_numpy(rewards) looks like torch.Size([758])\n", + "rewards looks like (1955,)\n", + "log_probs looks like (1955,)\n", + "logs prob looks like torch.Size([1955])\n", + "torch.from_numpy(rewards) looks like torch.Size([1955])\n", + "rewards looks like (1770,)\n", + "log_probs looks like (1770,)\n", + "logs prob looks like torch.Size([1770])\n", + "torch.from_numpy(rewards) looks like torch.Size([1770])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (507,)\n", + "log_probs looks like (507,)\n", + "logs prob looks like torch.Size([507])\n", + "torch.from_numpy(rewards) looks like torch.Size([507])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (1341,)\n", + "log_probs looks like (1341,)\n", + "logs prob looks like torch.Size([1341])\n", + "torch.from_numpy(rewards) looks like torch.Size([1341])\n", + "rewards looks like (1489,)\n", + "log_probs looks like (1489,)\n", + "logs prob looks like torch.Size([1489])\n", + "torch.from_numpy(rewards) looks like torch.Size([1489])\n", + "rewards looks like (3342,)\n", + "log_probs looks like (3342,)\n", + "logs prob looks like torch.Size([3342])\n", + "torch.from_numpy(rewards) looks like torch.Size([3342])\n", + "rewards looks like (1891,)\n", + "log_probs looks like (1891,)\n", + "logs prob looks like torch.Size([1891])\n", + "torch.from_numpy(rewards) looks like torch.Size([1891])\n", + "rewards looks like (1401,)\n", + "log_probs looks like (1401,)\n", + "logs prob looks like torch.Size([1401])\n", + "torch.from_numpy(rewards) looks like torch.Size([1401])\n", + "rewards looks like (2964,)\n", + 
"log_probs looks like (2964,)\n", + "logs prob looks like torch.Size([2964])\n", + "torch.from_numpy(rewards) looks like torch.Size([2964])\n", + "rewards looks like (1404,)\n", + "log_probs looks like (1404,)\n", + "logs prob looks like torch.Size([1404])\n", + "torch.from_numpy(rewards) looks like torch.Size([1404])\n", + "rewards looks like (780,)\n", + "log_probs looks like (780,)\n", + "logs prob looks like torch.Size([780])\n", + "torch.from_numpy(rewards) looks like torch.Size([780])\n", + "rewards looks like (1632,)\n", + "log_probs looks like (1632,)\n", + "logs prob looks like torch.Size([1632])\n", + "torch.from_numpy(rewards) looks like torch.Size([1632])\n", + "rewards looks like (1578,)\n", + "log_probs looks like (1578,)\n", + "logs prob looks like torch.Size([1578])\n", + "torch.from_numpy(rewards) looks like torch.Size([1578])\n", + "rewards looks like (1082,)\n", + "log_probs looks like (1082,)\n", + "logs prob looks like torch.Size([1082])\n", + "torch.from_numpy(rewards) looks like torch.Size([1082])\n", + "rewards looks like (1423,)\n", + "log_probs looks like (1423,)\n", + "logs prob looks like torch.Size([1423])\n", + "torch.from_numpy(rewards) looks like torch.Size([1423])\n", + "rewards looks like (2867,)\n", + "log_probs looks like (2867,)\n", + "logs prob looks like torch.Size([2867])\n", + "torch.from_numpy(rewards) looks like torch.Size([2867])\n", + "rewards looks like (1733,)\n", + "log_probs looks like (1733,)\n", + "logs prob looks like torch.Size([1733])\n", + "torch.from_numpy(rewards) looks like torch.Size([1733])\n", + "rewards looks like (646,)\n", + "log_probs looks like (646,)\n", + "logs prob looks like torch.Size([646])\n", + "torch.from_numpy(rewards) looks like torch.Size([646])\n", + "rewards looks like (1576,)\n", + "log_probs looks like (1576,)\n", + "logs prob looks like torch.Size([1576])\n", + "torch.from_numpy(rewards) looks like torch.Size([1576])\n", + "rewards looks like (1869,)\n", + "log_probs looks like 
(1869,)\n", + "logs prob looks like torch.Size([1869])\n", + "torch.from_numpy(rewards) looks like torch.Size([1869])\n", + "rewards looks like (1862,)\n", + "log_probs looks like (1862,)\n", + "logs prob looks like torch.Size([1862])\n", + "torch.from_numpy(rewards) looks like torch.Size([1862])\n", + "rewards looks like (3182,)\n", + "log_probs looks like (3182,)\n", + "logs prob looks like torch.Size([3182])\n", + "torch.from_numpy(rewards) looks like torch.Size([3182])\n", + "rewards looks like (1746,)\n", + "log_probs looks like (1746,)\n", + "logs prob looks like torch.Size([1746])\n", + "torch.from_numpy(rewards) looks like torch.Size([1746])\n", + "rewards looks like (1855,)\n", + "log_probs looks like (1855,)\n", + "logs prob looks like torch.Size([1855])\n", + "torch.from_numpy(rewards) looks like torch.Size([1855])\n", + "rewards looks like (2710,)\n", + "log_probs looks like (2710,)\n", + "logs prob looks like torch.Size([2710])\n", + "torch.from_numpy(rewards) looks like torch.Size([2710])\n", + "rewards looks like (1707,)\n", + "log_probs looks like (1707,)\n", + "logs prob looks like torch.Size([1707])\n", + "torch.from_numpy(rewards) looks like torch.Size([1707])\n", + "rewards looks like (1723,)\n", + "log_probs looks like (1723,)\n", + "logs prob looks like torch.Size([1723])\n", + "torch.from_numpy(rewards) looks like torch.Size([1723])\n", + "rewards looks like (1590,)\n", + "log_probs looks like (1590,)\n", + "logs prob looks like torch.Size([1590])\n", + "torch.from_numpy(rewards) looks like torch.Size([1590])\n", + "rewards looks like (1432,)\n", + "log_probs looks like (1432,)\n", + "logs prob looks like torch.Size([1432])\n", + "torch.from_numpy(rewards) looks like torch.Size([1432])\n", + "rewards looks like (2742,)\n", + "log_probs looks like (2742,)\n", + "logs prob looks like torch.Size([2742])\n", + "torch.from_numpy(rewards) looks like torch.Size([2742])\n", + "rewards looks like (3007,)\n", + "log_probs looks like (3007,)\n", + "logs 
prob looks like torch.Size([3007])\n", + "torch.from_numpy(rewards) looks like torch.Size([3007])\n", + "rewards looks like (2064,)\n", + "log_probs looks like (2064,)\n", + "logs prob looks like torch.Size([2064])\n", + "torch.from_numpy(rewards) looks like torch.Size([2064])\n", + "rewards looks like (1447,)\n", + "log_probs looks like (1447,)\n", + "logs prob looks like torch.Size([1447])\n", + "torch.from_numpy(rewards) looks like torch.Size([1447])\n", + "rewards looks like (4007,)\n", + "log_probs looks like (4007,)\n", + "logs prob looks like torch.Size([4007])\n", + "torch.from_numpy(rewards) looks like torch.Size([4007])\n", + "rewards looks like (611,)\n", + "log_probs looks like (611,)\n", + "logs prob looks like torch.Size([611])\n", + "torch.from_numpy(rewards) looks like torch.Size([611])\n", + "rewards looks like (1633,)\n", + "log_probs looks like (1633,)\n", + "logs prob looks like torch.Size([1633])\n", + "torch.from_numpy(rewards) looks like torch.Size([1633])\n", + "rewards looks like (3295,)\n", + "log_probs looks like (3295,)\n", + "logs prob looks like torch.Size([3295])\n", + "torch.from_numpy(rewards) looks like torch.Size([3295])\n", + "rewards looks like (975,)\n", + "log_probs looks like (975,)\n", + "logs prob looks like torch.Size([975])\n", + "torch.from_numpy(rewards) looks like torch.Size([975])\n", + "rewards looks like (1991,)\n", + "log_probs looks like (1991,)\n", + "logs prob looks like torch.Size([1991])\n", + "torch.from_numpy(rewards) looks like torch.Size([1991])\n", + "rewards looks like (2409,)\n", + "log_probs looks like (2409,)\n", + "logs prob looks like torch.Size([2409])\n", + "torch.from_numpy(rewards) looks like torch.Size([2409])\n", + "rewards looks like (1587,)\n", + "log_probs looks like (1587,)\n", + "logs prob looks like torch.Size([1587])\n", + "torch.from_numpy(rewards) looks like torch.Size([1587])\n", + "rewards looks like (1334,)\n", + "log_probs looks like (1334,)\n", + "logs prob looks like 
torch.Size([1334])\n", + "torch.from_numpy(rewards) looks like torch.Size([1334])\n", + "rewards looks like (1070,)\n", + "log_probs looks like (1070,)\n", + "logs prob looks like torch.Size([1070])\n", + "torch.from_numpy(rewards) looks like torch.Size([1070])\n", + "rewards looks like (1082,)\n", + "log_probs looks like (1082,)\n", + "logs prob looks like torch.Size([1082])\n", + "torch.from_numpy(rewards) looks like torch.Size([1082])\n", + "rewards looks like (1084,)\n", + "log_probs looks like (1084,)\n", + "logs prob looks like torch.Size([1084])\n", + "torch.from_numpy(rewards) looks like torch.Size([1084])\n", + "rewards looks like (1192,)\n", + "log_probs looks like (1192,)\n", + "logs prob looks like torch.Size([1192])\n", + "torch.from_numpy(rewards) looks like torch.Size([1192])\n", + "rewards looks like (1287,)\n", + "log_probs looks like (1287,)\n", + "logs prob looks like torch.Size([1287])\n", + "torch.from_numpy(rewards) looks like torch.Size([1287])\n", + "rewards looks like (1718,)\n", + "log_probs looks like (1718,)\n", + "logs prob looks like torch.Size([1718])\n", + "torch.from_numpy(rewards) looks like torch.Size([1718])\n", + "rewards looks like (1859,)\n", + "log_probs looks like (1859,)\n", + "logs prob looks like torch.Size([1859])\n", + "torch.from_numpy(rewards) looks like torch.Size([1859])\n", + "rewards looks like (1215,)\n", + "log_probs looks like (1215,)\n", + "logs prob looks like torch.Size([1215])\n", + "torch.from_numpy(rewards) looks like torch.Size([1215])\n", + "rewards looks like (1181,)\n", + "log_probs looks like (1181,)\n", + "logs prob looks like torch.Size([1181])\n", + "torch.from_numpy(rewards) looks like torch.Size([1181])\n", + "rewards looks like (1378,)\n", + "log_probs looks like (1378,)\n", + "logs prob looks like torch.Size([1378])\n", + "torch.from_numpy(rewards) looks like torch.Size([1378])\n", + "rewards looks like (1851,)\n", + "log_probs looks like (1851,)\n", + "logs prob looks like 
torch.Size([1851])\n", + "torch.from_numpy(rewards) looks like torch.Size([1851])\n", + "rewards looks like (2218,)\n", + "log_probs looks like (2218,)\n", + "logs prob looks like torch.Size([2218])\n", + "torch.from_numpy(rewards) looks like torch.Size([2218])\n", + "rewards looks like (2502,)\n", + "log_probs looks like (2502,)\n", + "logs prob looks like torch.Size([2502])\n", + "torch.from_numpy(rewards) looks like torch.Size([2502])\n", + "rewards looks like (1642,)\n", + "log_probs looks like (1642,)\n", + "logs prob looks like torch.Size([1642])\n", + "torch.from_numpy(rewards) looks like torch.Size([1642])\n", + "rewards looks like (1892,)\n", + "log_probs looks like (1892,)\n", + "logs prob looks like torch.Size([1892])\n", + "torch.from_numpy(rewards) looks like torch.Size([1892])\n", + "rewards looks like (2003,)\n", + "log_probs looks like (2003,)\n", + "logs prob looks like torch.Size([2003])\n", + "torch.from_numpy(rewards) looks like torch.Size([2003])\n", + "rewards looks like (3407,)\n", + "log_probs looks like (3407,)\n", + "logs prob looks like torch.Size([3407])\n", + "torch.from_numpy(rewards) looks like torch.Size([3407])\n", + "rewards looks like (3425,)\n", + "log_probs looks like (3425,)\n", + "logs prob looks like torch.Size([3425])\n", + "torch.from_numpy(rewards) looks like torch.Size([3425])\n", + "rewards looks like (1840,)\n", + "log_probs looks like (1840,)\n", + "logs prob looks like torch.Size([1840])\n", + "torch.from_numpy(rewards) looks like torch.Size([1840])\n", + "rewards looks like (1529,)\n", + "log_probs looks like (1529,)\n", + "logs prob looks like torch.Size([1529])\n", + "torch.from_numpy(rewards) looks like torch.Size([1529])\n", + "rewards looks like (1407,)\n", + "log_probs looks like (1407,)\n", + "logs prob looks like torch.Size([1407])\n", + "torch.from_numpy(rewards) looks like torch.Size([1407])\n", + "rewards looks like (2541,)\n", + "log_probs looks like (2541,)\n", + "logs prob looks like 
torch.Size([2541])\n", + "torch.from_numpy(rewards) looks like torch.Size([2541])\n", + "rewards looks like (1194,)\n", + "log_probs looks like (1194,)\n", + "logs prob looks like torch.Size([1194])\n", + "torch.from_numpy(rewards) looks like torch.Size([1194])\n", + "rewards looks like (1431,)\n", + "log_probs looks like (1431,)\n", + "logs prob looks like torch.Size([1431])\n", + "torch.from_numpy(rewards) looks like torch.Size([1431])\n", + "rewards looks like (3340,)\n", + "log_probs looks like (3340,)\n", + "logs prob looks like torch.Size([3340])\n", + "torch.from_numpy(rewards) looks like torch.Size([3340])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (1821,)\n", + "log_probs looks like (1821,)\n", + "logs prob looks like torch.Size([1821])\n", + "torch.from_numpy(rewards) looks like torch.Size([1821])\n", + "rewards looks like (1906,)\n", + "log_probs looks like (1906,)\n", + "logs prob looks like torch.Size([1906])\n", + "torch.from_numpy(rewards) looks like torch.Size([1906])\n", + "rewards looks like (2688,)\n", + "log_probs looks like (2688,)\n", + "logs prob looks like torch.Size([2688])\n", + "torch.from_numpy(rewards) looks like torch.Size([2688])\n", + "rewards looks like (1169,)\n", + "log_probs looks like (1169,)\n", + "logs prob looks like torch.Size([1169])\n", + "torch.from_numpy(rewards) looks like torch.Size([1169])\n", + "rewards looks like (1444,)\n", + "log_probs looks like (1444,)\n", + "logs prob looks like torch.Size([1444])\n", + "torch.from_numpy(rewards) looks like torch.Size([1444])\n", + "rewards looks like (1376,)\n", + "log_probs looks like (1376,)\n", + "logs prob looks like torch.Size([1376])\n", + "torch.from_numpy(rewards) looks like torch.Size([1376])\n", + "rewards looks like (1395,)\n", + "log_probs looks like (1395,)\n", + "logs prob looks like torch.Size([1395])\n", 
+ "torch.from_numpy(rewards) looks like torch.Size([1395])\n", + "rewards looks like (899,)\n", + "log_probs looks like (899,)\n", + "logs prob looks like torch.Size([899])\n", + "torch.from_numpy(rewards) looks like torch.Size([899])\n", + "rewards looks like (2152,)\n", + "log_probs looks like (2152,)\n", + "logs prob looks like torch.Size([2152])\n", + "torch.from_numpy(rewards) looks like torch.Size([2152])\n", + "rewards looks like (2294,)\n", + "log_probs looks like (2294,)\n", + "logs prob looks like torch.Size([2294])\n", + "torch.from_numpy(rewards) looks like torch.Size([2294])\n", + "rewards looks like (881,)\n", + "log_probs looks like (881,)\n", + "logs prob looks like torch.Size([881])\n", + "torch.from_numpy(rewards) looks like torch.Size([881])\n", + "rewards looks like (1050,)\n", + "log_probs looks like (1050,)\n", + "logs prob looks like torch.Size([1050])\n", + "torch.from_numpy(rewards) looks like torch.Size([1050])\n", + "rewards looks like (1294,)\n", + "log_probs looks like (1294,)\n", + "logs prob looks like torch.Size([1294])\n", + "torch.from_numpy(rewards) looks like torch.Size([1294])\n", + "rewards looks like (1419,)\n", + "log_probs looks like (1419,)\n", + "logs prob looks like torch.Size([1419])\n", + "torch.from_numpy(rewards) looks like torch.Size([1419])\n", + "rewards looks like (1433,)\n", + "log_probs looks like (1433,)\n", + "logs prob looks like torch.Size([1433])\n", + "torch.from_numpy(rewards) looks like torch.Size([1433])\n", + "rewards looks like (2196,)\n", + "log_probs looks like (2196,)\n", + "logs prob looks like torch.Size([2196])\n", + "torch.from_numpy(rewards) looks like torch.Size([2196])\n", + "rewards looks like (1811,)\n", + "log_probs looks like (1811,)\n", + "logs prob looks like torch.Size([1811])\n", + "torch.from_numpy(rewards) looks like torch.Size([1811])\n", + "rewards looks like (1418,)\n", + "log_probs looks like (1418,)\n", + "logs prob looks like torch.Size([1418])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1418])\n", + "rewards looks like (1536,)\n", + "log_probs looks like (1536,)\n", + "logs prob looks like torch.Size([1536])\n", + "torch.from_numpy(rewards) looks like torch.Size([1536])\n", + "rewards looks like (1353,)\n", + "log_probs looks like (1353,)\n", + "logs prob looks like torch.Size([1353])\n", + "torch.from_numpy(rewards) looks like torch.Size([1353])\n", + "rewards looks like (1260,)\n", + "log_probs looks like (1260,)\n", + "logs prob looks like torch.Size([1260])\n", + "torch.from_numpy(rewards) looks like torch.Size([1260])\n", + "rewards looks like (1514,)\n", + "log_probs looks like (1514,)\n", + "logs prob looks like torch.Size([1514])\n", + "torch.from_numpy(rewards) looks like torch.Size([1514])\n", + "rewards looks like (2095,)\n", + "log_probs looks like (2095,)\n", + "logs prob looks like torch.Size([2095])\n", + "torch.from_numpy(rewards) looks like torch.Size([2095])\n", + "rewards looks like (1695,)\n", + "log_probs looks like (1695,)\n", + "logs prob looks like torch.Size([1695])\n", + "torch.from_numpy(rewards) looks like torch.Size([1695])\n", + "rewards looks like (2109,)\n", + "log_probs looks like (2109,)\n", + "logs prob looks like torch.Size([2109])\n", + "torch.from_numpy(rewards) looks like torch.Size([2109])\n", + "rewards looks like (967,)\n", + "log_probs looks like (967,)\n", + "logs prob looks like torch.Size([967])\n", + "torch.from_numpy(rewards) looks like torch.Size([967])\n", + "rewards looks like (1231,)\n", + "log_probs looks like (1231,)\n", + "logs prob looks like torch.Size([1231])\n", + "torch.from_numpy(rewards) looks like torch.Size([1231])\n", + "rewards looks like (1355,)\n", + "log_probs looks like (1355,)\n", + "logs prob looks like torch.Size([1355])\n", + "torch.from_numpy(rewards) looks like torch.Size([1355])\n", + "rewards looks like (1351,)\n", + "log_probs looks like (1351,)\n", + "logs prob looks like torch.Size([1351])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1351])\n", + "rewards looks like (1674,)\n", + "log_probs looks like (1674,)\n", + "logs prob looks like torch.Size([1674])\n", + "torch.from_numpy(rewards) looks like torch.Size([1674])\n", + "rewards looks like (2394,)\n", + "log_probs looks like (2394,)\n", + "logs prob looks like torch.Size([2394])\n", + "torch.from_numpy(rewards) looks like torch.Size([2394])\n", + "rewards looks like (2296,)\n", + "log_probs looks like (2296,)\n", + "logs prob looks like torch.Size([2296])\n", + "torch.from_numpy(rewards) looks like torch.Size([2296])\n", + "rewards looks like (897,)\n", + "log_probs looks like (897,)\n", + "logs prob looks like torch.Size([897])\n", + "torch.from_numpy(rewards) looks like torch.Size([897])\n", + "rewards looks like (2389,)\n", + "log_probs looks like (2389,)\n", + "logs prob looks like torch.Size([2389])\n", + "torch.from_numpy(rewards) looks like torch.Size([2389])\n", + "rewards looks like (1798,)\n", + "log_probs looks like (1798,)\n", + "logs prob looks like torch.Size([1798])\n", + "torch.from_numpy(rewards) looks like torch.Size([1798])\n", + "rewards looks like (1232,)\n", + "log_probs looks like (1232,)\n", + "logs prob looks like torch.Size([1232])\n", + "torch.from_numpy(rewards) looks like torch.Size([1232])\n", + "rewards looks like (1173,)\n", + "log_probs looks like (1173,)\n", + "logs prob looks like torch.Size([1173])\n", + "torch.from_numpy(rewards) looks like torch.Size([1173])\n", + "rewards looks like (1165,)\n", + "log_probs looks like (1165,)\n", + "logs prob looks like torch.Size([1165])\n", + "torch.from_numpy(rewards) looks like torch.Size([1165])\n", + "rewards looks like (1602,)\n", + "log_probs looks like (1602,)\n", + "logs prob looks like torch.Size([1602])\n", + "torch.from_numpy(rewards) looks like torch.Size([1602])\n", + "rewards looks like (1164,)\n", + "log_probs looks like (1164,)\n", + "logs prob looks like torch.Size([1164])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1164])\n", + "rewards looks like (2235,)\n", + "log_probs looks like (2235,)\n", + "logs prob looks like torch.Size([2235])\n", + "torch.from_numpy(rewards) looks like torch.Size([2235])\n", + "rewards looks like (1038,)\n", + "log_probs looks like (1038,)\n", + "logs prob looks like torch.Size([1038])\n", + "torch.from_numpy(rewards) looks like torch.Size([1038])\n", + "rewards looks like (1698,)\n", + "log_probs looks like (1698,)\n", + "logs prob looks like torch.Size([1698])\n", + "torch.from_numpy(rewards) looks like torch.Size([1698])\n", + "rewards looks like (1436,)\n", + "log_probs looks like (1436,)\n", + "logs prob looks like torch.Size([1436])\n", + "torch.from_numpy(rewards) looks like torch.Size([1436])\n", + "rewards looks like (1223,)\n", + "log_probs looks like (1223,)\n", + "logs prob looks like torch.Size([1223])\n", + "torch.from_numpy(rewards) looks like torch.Size([1223])\n", + "rewards looks like (2006,)\n", + "log_probs looks like (2006,)\n", + "logs prob looks like torch.Size([2006])\n", + "torch.from_numpy(rewards) looks like torch.Size([2006])\n", + "rewards looks like (1162,)\n", + "log_probs looks like (1162,)\n", + "logs prob looks like torch.Size([1162])\n", + "torch.from_numpy(rewards) looks like torch.Size([1162])\n", + "rewards looks like (2239,)\n", + "log_probs looks like (2239,)\n", + "logs prob looks like torch.Size([2239])\n", + "torch.from_numpy(rewards) looks like torch.Size([2239])\n", + "rewards looks like (1104,)\n", + "log_probs looks like (1104,)\n", + "logs prob looks like torch.Size([1104])\n", + "torch.from_numpy(rewards) looks like torch.Size([1104])\n", + "rewards looks like (1389,)\n", + "log_probs looks like (1389,)\n", + "logs prob looks like torch.Size([1389])\n", + "torch.from_numpy(rewards) looks like torch.Size([1389])\n", + "rewards looks like (1419,)\n", + "log_probs looks like (1419,)\n", + "logs prob looks like torch.Size([1419])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1419])\n", + "rewards looks like (1526,)\n", + "log_probs looks like (1526,)\n", + "logs prob looks like torch.Size([1526])\n", + "torch.from_numpy(rewards) looks like torch.Size([1526])\n", + "rewards looks like (1618,)\n", + "log_probs looks like (1618,)\n", + "logs prob looks like torch.Size([1618])\n", + "torch.from_numpy(rewards) looks like torch.Size([1618])\n", + "rewards looks like (2276,)\n", + "log_probs looks like (2276,)\n", + "logs prob looks like torch.Size([2276])\n", + "torch.from_numpy(rewards) looks like torch.Size([2276])\n", + "rewards looks like (2973,)\n", + "log_probs looks like (2973,)\n", + "logs prob looks like torch.Size([2973])\n", + "torch.from_numpy(rewards) looks like torch.Size([2973])\n", + "rewards looks like (1418,)\n", + "log_probs looks like (1418,)\n", + "logs prob looks like torch.Size([1418])\n", + "torch.from_numpy(rewards) looks like torch.Size([1418])\n", + "rewards looks like (1273,)\n", + "log_probs looks like (1273,)\n", + "logs prob looks like torch.Size([1273])\n", + "torch.from_numpy(rewards) looks like torch.Size([1273])\n", + "rewards looks like (2355,)\n", + "log_probs looks like (2355,)\n", + "logs prob looks like torch.Size([2355])\n", + "torch.from_numpy(rewards) looks like torch.Size([2355])\n", + "rewards looks like (1308,)\n", + "log_probs looks like (1308,)\n", + "logs prob looks like torch.Size([1308])\n", + "torch.from_numpy(rewards) looks like torch.Size([1308])\n", + "rewards looks like (1403,)\n", + "log_probs looks like (1403,)\n", + "logs prob looks like torch.Size([1403])\n", + "torch.from_numpy(rewards) looks like torch.Size([1403])\n", + "rewards looks like (1794,)\n", + "log_probs looks like (1794,)\n", + "logs prob looks like torch.Size([1794])\n", + "torch.from_numpy(rewards) looks like torch.Size([1794])\n", + "rewards looks like (1101,)\n", + "log_probs looks like (1101,)\n", + "logs prob looks like torch.Size([1101])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1101])\n", + "rewards looks like (1165,)\n", + "log_probs looks like (1165,)\n", + "logs prob looks like torch.Size([1165])\n", + "torch.from_numpy(rewards) looks like torch.Size([1165])\n", + "rewards looks like (1162,)\n", + "log_probs looks like (1162,)\n", + "logs prob looks like torch.Size([1162])\n", + "torch.from_numpy(rewards) looks like torch.Size([1162])\n", + "rewards looks like (1317,)\n", + "log_probs looks like (1317,)\n", + "logs prob looks like torch.Size([1317])\n", + "torch.from_numpy(rewards) looks like torch.Size([1317])\n", + "rewards looks like (993,)\n", + "log_probs looks like (993,)\n", + "logs prob looks like torch.Size([993])\n", + "torch.from_numpy(rewards) looks like torch.Size([993])\n", + "rewards looks like (2078,)\n", + "log_probs looks like (2078,)\n", + "logs prob looks like torch.Size([2078])\n", + "torch.from_numpy(rewards) looks like torch.Size([2078])\n", + "rewards looks like (1419,)\n", + "log_probs looks like (1419,)\n", + "logs prob looks like torch.Size([1419])\n", + "torch.from_numpy(rewards) looks like torch.Size([1419])\n", + "rewards looks like (1354,)\n", + "log_probs looks like (1354,)\n", + "logs prob looks like torch.Size([1354])\n", + "torch.from_numpy(rewards) looks like torch.Size([1354])\n", + "rewards looks like (1216,)\n", + "log_probs looks like (1216,)\n", + "logs prob looks like torch.Size([1216])\n", + "torch.from_numpy(rewards) looks like torch.Size([1216])\n", + "rewards looks like (1661,)\n", + "log_probs looks like (1661,)\n", + "logs prob looks like torch.Size([1661])\n", + "torch.from_numpy(rewards) looks like torch.Size([1661])\n", + "rewards looks like (2095,)\n", + "log_probs looks like (2095,)\n", + "logs prob looks like torch.Size([2095])\n", + "torch.from_numpy(rewards) looks like torch.Size([2095])\n", + "rewards looks like (2455,)\n", + "log_probs looks like (2455,)\n", + "logs prob looks like torch.Size([2455])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2455])\n", + "rewards looks like (2383,)\n", + "log_probs looks like (2383,)\n", + "logs prob looks like torch.Size([2383])\n", + "torch.from_numpy(rewards) looks like torch.Size([2383])\n", + "rewards looks like (2222,)\n", + "log_probs looks like (2222,)\n", + "logs prob looks like torch.Size([2222])\n", + "torch.from_numpy(rewards) looks like torch.Size([2222])\n", + "rewards looks like (2269,)\n", + "log_probs looks like (2269,)\n", + "logs prob looks like torch.Size([2269])\n", + "torch.from_numpy(rewards) looks like torch.Size([2269])\n", + "rewards looks like (2995,)\n", + "log_probs looks like (2995,)\n", + "logs prob looks like torch.Size([2995])\n", + "torch.from_numpy(rewards) looks like torch.Size([2995])\n", + "rewards looks like (1474,)\n", + "log_probs looks like (1474,)\n", + "logs prob looks like torch.Size([1474])\n", + "torch.from_numpy(rewards) looks like torch.Size([1474])\n", + "rewards looks like (2666,)\n", + "log_probs looks like (2666,)\n", + "logs prob looks like torch.Size([2666])\n", + "torch.from_numpy(rewards) looks like torch.Size([2666])\n", + "rewards looks like (1386,)\n", + "log_probs looks like (1386,)\n", + "logs prob looks like torch.Size([1386])\n", + "torch.from_numpy(rewards) looks like torch.Size([1386])\n", + "rewards looks like (2039,)\n", + "log_probs looks like (2039,)\n", + "logs prob looks like torch.Size([2039])\n", + "torch.from_numpy(rewards) looks like torch.Size([2039])\n", + "rewards looks like (2172,)\n", + "log_probs looks like (2172,)\n", + "logs prob looks like torch.Size([2172])\n", + "torch.from_numpy(rewards) looks like torch.Size([2172])\n", + "rewards looks like (2070,)\n", + "log_probs looks like (2070,)\n", + "logs prob looks like torch.Size([2070])\n", + "torch.from_numpy(rewards) looks like torch.Size([2070])\n", + "rewards looks like (2534,)\n", + "log_probs looks like (2534,)\n", + "logs prob looks like torch.Size([2534])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2534])\n", + "rewards looks like (1660,)\n", + "log_probs looks like (1660,)\n", + "logs prob looks like torch.Size([1660])\n", + "torch.from_numpy(rewards) looks like torch.Size([1660])\n", + "rewards looks like (1406,)\n", + "log_probs looks like (1406,)\n", + "logs prob looks like torch.Size([1406])\n", + "torch.from_numpy(rewards) looks like torch.Size([1406])\n", + "rewards looks like (1472,)\n", + "log_probs looks like (1472,)\n", + "logs prob looks like torch.Size([1472])\n", + "torch.from_numpy(rewards) looks like torch.Size([1472])\n", + "rewards looks like (2711,)\n", + "log_probs looks like (2711,)\n", + "logs prob looks like torch.Size([2711])\n", + "torch.from_numpy(rewards) looks like torch.Size([2711])\n", + "rewards looks like (1529,)\n", + "log_probs looks like (1529,)\n", + "logs prob looks like torch.Size([1529])\n", + "torch.from_numpy(rewards) looks like torch.Size([1529])\n", + "rewards looks like (1867,)\n", + "log_probs looks like (1867,)\n", + "logs prob looks like torch.Size([1867])\n", + "torch.from_numpy(rewards) looks like torch.Size([1867])\n", + "rewards looks like (1218,)\n", + "log_probs looks like (1218,)\n", + "logs prob looks like torch.Size([1218])\n", + "torch.from_numpy(rewards) looks like torch.Size([1218])\n", + "rewards looks like (1345,)\n", + "log_probs looks like (1345,)\n", + "logs prob looks like torch.Size([1345])\n", + "torch.from_numpy(rewards) looks like torch.Size([1345])\n", + "rewards looks like (1188,)\n", + "log_probs looks like (1188,)\n", + "logs prob looks like torch.Size([1188])\n", + "torch.from_numpy(rewards) looks like torch.Size([1188])\n", + "rewards looks like (1945,)\n", + "log_probs looks like (1945,)\n", + "logs prob looks like torch.Size([1945])\n", + "torch.from_numpy(rewards) looks like torch.Size([1945])\n", + "rewards looks like (987,)\n", + "log_probs looks like (987,)\n", + "logs prob looks like torch.Size([987])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([987])\n", + "rewards looks like (2017,)\n", + "log_probs looks like (2017,)\n", + "logs prob looks like torch.Size([2017])\n", + "torch.from_numpy(rewards) looks like torch.Size([2017])\n", + "rewards looks like (2001,)\n", + "log_probs looks like (2001,)\n", + "logs prob looks like torch.Size([2001])\n", + "torch.from_numpy(rewards) looks like torch.Size([2001])\n", + "rewards looks like (1335,)\n", + "log_probs looks like (1335,)\n", + "logs prob looks like torch.Size([1335])\n", + "torch.from_numpy(rewards) looks like torch.Size([1335])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (2834,)\n", + "log_probs looks like (2834,)\n", + "logs prob looks like torch.Size([2834])\n", + "torch.from_numpy(rewards) looks like torch.Size([2834])\n", + "rewards looks like (1391,)\n", + "log_probs looks like (1391,)\n", + "logs prob looks like torch.Size([1391])\n", + "torch.from_numpy(rewards) looks like torch.Size([1391])\n", + "rewards looks like (1852,)\n", + "log_probs looks like (1852,)\n", + "logs prob looks like torch.Size([1852])\n", + "torch.from_numpy(rewards) looks like torch.Size([1852])\n", + "rewards looks like (1256,)\n", + "log_probs looks like (1256,)\n", + "logs prob looks like torch.Size([1256])\n", + "torch.from_numpy(rewards) looks like torch.Size([1256])\n", + "rewards looks like (1184,)\n", + "log_probs looks like (1184,)\n", + "logs prob looks like torch.Size([1184])\n", + "torch.from_numpy(rewards) looks like torch.Size([1184])\n", + "rewards looks like (1939,)\n", + "log_probs looks like (1939,)\n", + "logs prob looks like torch.Size([1939])\n", + "torch.from_numpy(rewards) looks like torch.Size([1939])\n", + "rewards looks like (1274,)\n", + "log_probs looks like (1274,)\n", + "logs prob looks like torch.Size([1274])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1274])\n", + "rewards looks like (1367,)\n", + "log_probs looks like (1367,)\n", + "logs prob looks like torch.Size([1367])\n", + "torch.from_numpy(rewards) looks like torch.Size([1367])\n", + "rewards looks like (1284,)\n", + "log_probs looks like (1284,)\n", + "logs prob looks like torch.Size([1284])\n", + "torch.from_numpy(rewards) looks like torch.Size([1284])\n", + "rewards looks like (1127,)\n", + "log_probs looks like (1127,)\n", + "logs prob looks like torch.Size([1127])\n", + "torch.from_numpy(rewards) looks like torch.Size([1127])\n", + "rewards looks like (1298,)\n", + "log_probs looks like (1298,)\n", + "logs prob looks like torch.Size([1298])\n", + "torch.from_numpy(rewards) looks like torch.Size([1298])\n", + "rewards looks like (1638,)\n", + "log_probs looks like (1638,)\n", + "logs prob looks like torch.Size([1638])\n", + "torch.from_numpy(rewards) looks like torch.Size([1638])\n", + "rewards looks like (1144,)\n", + "log_probs looks like (1144,)\n", + "logs prob looks like torch.Size([1144])\n", + "torch.from_numpy(rewards) looks like torch.Size([1144])\n", + "rewards looks like (1370,)\n", + "log_probs looks like (1370,)\n", + "logs prob looks like torch.Size([1370])\n", + "torch.from_numpy(rewards) looks like torch.Size([1370])\n", + "rewards looks like (1835,)\n", + "log_probs looks like (1835,)\n", + "logs prob looks like torch.Size([1835])\n", + "torch.from_numpy(rewards) looks like torch.Size([1835])\n", + "rewards looks like (2149,)\n", + "log_probs looks like (2149,)\n", + "logs prob looks like torch.Size([2149])\n", + "torch.from_numpy(rewards) looks like torch.Size([2149])\n", + "rewards looks like (1033,)\n", + "log_probs looks like (1033,)\n", + "logs prob looks like torch.Size([1033])\n", + "torch.from_numpy(rewards) looks like torch.Size([1033])\n", + "rewards looks like (989,)\n", + "log_probs looks like (989,)\n", + "logs prob looks like torch.Size([989])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([989])\n", + "rewards looks like (1900,)\n", + "log_probs looks like (1900,)\n", + "logs prob looks like torch.Size([1900])\n", + "torch.from_numpy(rewards) looks like torch.Size([1900])\n", + "rewards looks like (1706,)\n", + "log_probs looks like (1706,)\n", + "logs prob looks like torch.Size([1706])\n", + "torch.from_numpy(rewards) looks like torch.Size([1706])\n", + "rewards looks like (1235,)\n", + "log_probs looks like (1235,)\n", + "logs prob looks like torch.Size([1235])\n", + "torch.from_numpy(rewards) looks like torch.Size([1235])\n", + "rewards looks like (2693,)\n", + "log_probs looks like (2693,)\n", + "logs prob looks like torch.Size([2693])\n", + "torch.from_numpy(rewards) looks like torch.Size([2693])\n", + "rewards looks like (1021,)\n", + "log_probs looks like (1021,)\n", + "logs prob looks like torch.Size([1021])\n", + "torch.from_numpy(rewards) looks like torch.Size([1021])\n", + "rewards looks like (1126,)\n", + "log_probs looks like (1126,)\n", + "logs prob looks like torch.Size([1126])\n", + "torch.from_numpy(rewards) looks like torch.Size([1126])\n", + "rewards looks like (1334,)\n", + "log_probs looks like (1334,)\n", + "logs prob looks like torch.Size([1334])\n", + "torch.from_numpy(rewards) looks like torch.Size([1334])\n", + "rewards looks like (1337,)\n", + "log_probs looks like (1337,)\n", + "logs prob looks like torch.Size([1337])\n", + "torch.from_numpy(rewards) looks like torch.Size([1337])\n", + "rewards looks like (1502,)\n", + "log_probs looks like (1502,)\n", + "logs prob looks like torch.Size([1502])\n", + "torch.from_numpy(rewards) looks like torch.Size([1502])\n", + "rewards looks like (2059,)\n", + "log_probs looks like (2059,)\n", + "logs prob looks like torch.Size([2059])\n", + "torch.from_numpy(rewards) looks like torch.Size([2059])\n", + "rewards looks like (2057,)\n", + "log_probs looks like (2057,)\n", + "logs prob looks like torch.Size([2057])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([2057])\n", + "rewards looks like (1300,)\n", + "log_probs looks like (1300,)\n", + "logs prob looks like torch.Size([1300])\n", + "torch.from_numpy(rewards) looks like torch.Size([1300])\n", + "rewards looks like (3078,)\n", + "log_probs looks like (3078,)\n", + "logs prob looks like torch.Size([3078])\n", + "torch.from_numpy(rewards) looks like torch.Size([3078])\n", + "rewards looks like (1724,)\n", + "log_probs looks like (1724,)\n", + "logs prob looks like torch.Size([1724])\n", + "torch.from_numpy(rewards) looks like torch.Size([1724])\n", + "rewards looks like (1468,)\n", + "log_probs looks like (1468,)\n", + "logs prob looks like torch.Size([1468])\n", + "torch.from_numpy(rewards) looks like torch.Size([1468])\n", + "rewards looks like (2674,)\n", + "log_probs looks like (2674,)\n", + "logs prob looks like torch.Size([2674])\n", + "torch.from_numpy(rewards) looks like torch.Size([2674])\n", + "rewards looks like (1376,)\n", + "log_probs looks like (1376,)\n", + "logs prob looks like torch.Size([1376])\n", + "torch.from_numpy(rewards) looks like torch.Size([1376])\n", + "rewards looks like (1564,)\n", + "log_probs looks like (1564,)\n", + "logs prob looks like torch.Size([1564])\n", + "torch.from_numpy(rewards) looks like torch.Size([1564])\n", + "rewards looks like (1452,)\n", + "log_probs looks like (1452,)\n", + "logs prob looks like torch.Size([1452])\n", + "torch.from_numpy(rewards) looks like torch.Size([1452])\n", + "rewards looks like (1205,)\n", + "log_probs looks like (1205,)\n", + "logs prob looks like torch.Size([1205])\n", + "torch.from_numpy(rewards) looks like torch.Size([1205])\n", + "rewards looks like (1520,)\n", + "log_probs looks like (1520,)\n", + "logs prob looks like torch.Size([1520])\n", + "torch.from_numpy(rewards) looks like torch.Size([1520])\n", + "rewards looks like (1099,)\n", + "log_probs looks like (1099,)\n", + "logs prob looks like torch.Size([1099])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1099])\n", + "rewards looks like (1506,)\n", + "log_probs looks like (1506,)\n", + "logs prob looks like torch.Size([1506])\n", + "torch.from_numpy(rewards) looks like torch.Size([1506])\n", + "rewards looks like (1175,)\n", + "log_probs looks like (1175,)\n", + "logs prob looks like torch.Size([1175])\n", + "torch.from_numpy(rewards) looks like torch.Size([1175])\n", + "rewards looks like (1251,)\n", + "log_probs looks like (1251,)\n", + "logs prob looks like torch.Size([1251])\n", + "torch.from_numpy(rewards) looks like torch.Size([1251])\n", + "rewards looks like (1318,)\n", + "log_probs looks like (1318,)\n", + "logs prob looks like torch.Size([1318])\n", + "torch.from_numpy(rewards) looks like torch.Size([1318])\n", + "rewards looks like (1446,)\n", + "log_probs looks like (1446,)\n", + "logs prob looks like torch.Size([1446])\n", + "torch.from_numpy(rewards) looks like torch.Size([1446])\n", + "rewards looks like (1220,)\n", + "log_probs looks like (1220,)\n", + "logs prob looks like torch.Size([1220])\n", + "torch.from_numpy(rewards) looks like torch.Size([1220])\n", + "rewards looks like (1343,)\n", + "log_probs looks like (1343,)\n", + "logs prob looks like torch.Size([1343])\n", + "torch.from_numpy(rewards) looks like torch.Size([1343])\n", + "rewards looks like (1186,)\n", + "log_probs looks like (1186,)\n", + "logs prob looks like torch.Size([1186])\n", + "torch.from_numpy(rewards) looks like torch.Size([1186])\n", + "rewards looks like (1443,)\n", + "log_probs looks like (1443,)\n", + "logs prob looks like torch.Size([1443])\n", + "torch.from_numpy(rewards) looks like torch.Size([1443])\n", + "rewards looks like (1212,)\n", + "log_probs looks like (1212,)\n", + "logs prob looks like torch.Size([1212])\n", + "torch.from_numpy(rewards) looks like torch.Size([1212])\n", + "rewards looks like (1346,)\n", + "log_probs looks like (1346,)\n", + "logs prob looks like torch.Size([1346])\n", + 
"torch.from_numpy(rewards) looks like torch.Size([1346])\n", + "rewards looks like (2124,)\n", + "log_probs looks like (2124,)\n", + "logs prob looks like torch.Size([2124])\n", + "torch.from_numpy(rewards) looks like torch.Size([2124])\n", + "rewards looks like (1461,)\n", + "log_probs looks like (1461,)\n", + "logs prob looks like torch.Size([1461])\n", + "torch.from_numpy(rewards) looks like torch.Size([1461])\n", + "rewards looks like (1425,)\n", + "log_probs looks like (1425,)\n", + "logs prob looks like torch.Size([1425])\n", + "torch.from_numpy(rewards) looks like torch.Size([1425])\n", + "rewards looks like (1457,)\n", + "log_probs looks like (1457,)\n", + "logs prob looks like torch.Size([1457])\n", + "torch.from_numpy(rewards) looks like torch.Size([1457])\n", + "rewards looks like (1223,)\n", + "log_probs looks like (1223,)\n", + "logs prob looks like torch.Size([1223])\n", + "torch.from_numpy(rewards) looks like torch.Size([1223])\n", + "rewards looks like (1310,)\n", + "log_probs looks like (1310,)\n", + "logs prob looks like torch.Size([1310])\n", + "torch.from_numpy(rewards) looks like torch.Size([1310])\n", + "rewards looks like (2446,)\n", + "log_probs looks like (2446,)\n", + "logs prob looks like torch.Size([2446])\n", + "torch.from_numpy(rewards) looks like torch.Size([2446])\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vNb_tuFYhKVK" + }, + "source": [ + "### 訓練結果\n", + "\n", + "訓練過程中,我們持續記下了 `avg_total_reward`,這個數值代表的是:每次更新 policy network 前,我們讓 agent 玩數個回合(episodes),而這些回合的平均 total rewards 為何。\n", + "理論上,若是 agent 一直在進步,則所得到的 `avg_total_reward` 也會持續上升,直至 250 上下。\n", + "若將其畫出來則結果如下:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "wZYOI8H10SHN", + "outputId": "80307382-3743-4f70-e08a-66c5e92451da" + }, + "source": [ + "end = time.time()\n", + "plt.plot(avg_total_rewards)\n", + "plt.title(\"Total 
Rewards\")\n", + "plt.show()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mV5jj4dThz0Y" + }, + "source": [ + "另外,`avg_final_reward` 代表的是多個回合的平均 final rewards,而 final reward 即是 agent 在單一回合中拿到的最後一個 reward。\n", + "如果同學們還記得環境給予登月小艇 reward 的方式,便會知道,不論**回合的最後**小艇是不幸墜毀、飛出畫面、或是靜止在地面上,都會受到額外地獎勵或處罰。\n", + "也因此,final reward 可被用來觀察 agent 的「著地」是否順利等資訊。" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "txDZ5vlGWz5w", + "outputId": "bc284774-255a-45ac-dabf-3dfb5e1e5565" + }, + "source": [ + "plt.plot(avg_final_rewards)\n", + "plt.title(\"Final Rewards\")\n", + "plt.show()\n" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gyT7tNwkVdS-" + }, + "source": [ + "訓練時間\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_t-JsKxUViFy", + "outputId": "333aa287-0455-4028-b91c-f83c8d2e1b57" + }, + "source": [ + "print(f\"total time is {end-start} sec\")" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "total time is 674.2419369220734 sec\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u2HaGRVEYGQS" + }, + "source": [ + "## 測試" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "5yFuUKKRYH73", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 500 + }, + "outputId": "7901d4d3-a71b-468e-a12e-6bd9edff551e" + }, + "source": [ + "fix(env, seed)\n", + "agent.network.eval() # 測試前先將 network 切換為 evaluation 模式\n", + "NUM_OF_TEST = 5 # Do not revise it !!!!!\n", + "test_total_reward = []\n", + "action_list = []\n", + "for i in range(NUM_OF_TEST):\n", + " actions = []\n", + " state = env.reset()\n", + "\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + " while not done:\n", + " action, _ = agent.sample(state)\n", + " actions.append(action)\n", + " state, reward, done, _ = env.step(action)\n", + "\n", + " total_reward += reward\n", + "\n", + " #img.set_data(env.render(mode='rgb_array'))\n", + " #display.display(plt.gcf())\n", + " #display.clear_output(wait=True)\n", + " print(total_reward)\n", + " test_total_reward.append(total_reward)\n", + "\n", + " action_list.append(actions) #儲存你測試的結果\n", + " print(\"length of actions is \", len(actions))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/__init__.py:422: UserWarning: 
torch.set_deterministic is deprecated and will be removed in a future release. Please use torch.use_deterministic_algorithms instead\n", + " \"torch.set_deterministic is deprecated and will be removed in a future \"\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "260.62265430635034\n", + "length of actions is 299\n", + "-212.89375915819693\n", + "length of actions is 319\n", + "11.862808485612831\n", + "length of actions is 241\n", + "8.015383611389638\n", + "length of actions is 231\n", + "-219.21903722619058\n", + "length of actions is 256\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Aex7mcKr0J01", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1706a79-1fbd-4d61-bdcd-ab257cb152e5" + }, + "source": [ + "print(f\"Your final reward is : %.2f\"%np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Your final reward is : -30.32\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "leyebGYRpqsF" + }, + "source": [ + "Action list 的長相" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hGAH4YWDpp4u", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c7f5fa21-7b7a-43a8-8478-df76dce7a4ad" + }, + "source": [ + "print(\"Action list looks like \", action_list)\n", + "print(\"Action list's shape looks like \", np.shape(action_list))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Action list looks like [[1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 3, 2, 2, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 2, 3, 2, 3, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 3, 2, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2, 2, 3, 
2, 2, 3, 2, 0, 2, 2, 3, 2, 0, 3, 2, 2, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 0, 1, 2, 2, 2, 0, 1, 2, 2, 1, 2, 1, 2, 1, 1, 2, 2, 0, 2, 1, 2, 2, 1, 1, 1, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 0, 1, 0, 2, 2, 2, 2, 3, 3, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 3, 2, 2, 0, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2], [2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 0, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 2, 3, 3, 3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 2, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2], [1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 0, 2, 2, 3, 3, 2, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 3, 3, 2, 3, 2, 2, 3, 3, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, 
1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 2, 2, 3, 3, 3, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2], [1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 2, 0, 1, 2, 2, 0, 2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 3, 2, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 3, 3, 2, 3, 2, 3, 3, 2, 2, 3, 3, 2, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]\n", + "Action list's shape looks like (5,)\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " return array(a, dtype, copy=False, order=order)\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l7sokqEUtrFY" + }, + "source": [ + "Action 的分布\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WHdAItjj1nxw", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5129773b-1f4a-4085-d2bf-3bc2abc2598c" + }, + "source": [ + "distribution = {}\n", + "for actions in action_list:\n", + " for action in actions:\n", + " if action not in distribution.keys():\n", + " distribution[action] = 1\n", + " else:\n", + " distribution[action] += 1\n", + "print(distribution)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "{1: 278, 2: 698, 3: 297, 0: 73}\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ricE0schY75M" + }, + "source": [ + "儲存 Model Testing的結果\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GZsMkGmIY42b", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8c55c932-4654-4f8c-f6b0-fa52ac3e8b96" + }, + "source": [ + "PATH = \"Action_List_test.npy\" # 可以改成你想取的名字或路徑\n", + "np.save(PATH ,np.array(action_list)) " + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " \n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "asK7WfbkaLjt" + }, + "source": [ + "### 你要交到 JudgeBoi 的檔案就是這個\n", + "儲存結果到本地端(也就是你的電腦裡)\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "c-CqyhHzaWAL", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "adfba5e6-a107-49aa-9f98-3c0655c5d6c2" + }, + "source": [ + "from google.colab import files\n", + "files.download(PATH)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " 
div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "download(\"download_5d13b99b-295d-4ab0-814c-b2d0fff26eff\", \"Action_List_test.npy\", 2999)" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "seT4NUmWmAZ1" + }, + "source": [ + "# Server 測試\n", + "到時候下面會是我們Server上測試的環境,可以給大家看一下自己的表現如何" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "U69c-YTxaw6b", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 412 + }, + "outputId": "50015892-29ae-4665-c66f-880aecf7be8f" + }, + "source": [ + "action_list = np.load(PATH,allow_pickle=True) #到時候你上傳的檔案\n", + "seed = 543 #到時候測試的seed 請不要更改\n", + "fix(env, seed)\n", + "\n", + "agent.network.eval() # 測試前先將 network 切換為 evaluation 模式\n", + "\n", + "test_total_reward = []\n", + "for actions in action_list:\n", + " state = env.reset()\n", + " img = plt.imshow(env.render(mode='rgb_array'))\n", + "\n", + " total_reward = 0\n", + "\n", + " done = False\n", + " # while not done:\n", + " done_count = 0\n", + " for action in actions:\n", + " # action, _ = agent1.sample(state)\n", + " state, reward, done, _ = env.step(action)\n", + " done_count += 1\n", + " total_reward += reward\n", + " if done:\n", + " \n", + " break\n", + " # img.set_data(env.render(mode='rgb_array'))\n", + " # display.display(plt.gcf())\n", + " # display.clear_output(wait=True)\n", + " print(f\"Your reward is : %.2f\"%total_reward)\n", + " test_total_reward.append(total_reward)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/__init__.py:422: UserWarning: torch.set_deterministic is deprecated and will be removed in a future release. 
Please use torch.use_deterministic_algorithms instead\n", + " \"torch.set_deterministic is deprecated and will be removed in a future \"\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "Your reward is : 260.62\n", + "Your reward is : -212.89\n", + "Your reward is : 11.86\n", + "Your reward is : 8.02\n", + "Your reward is : -219.22\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjFBWwQP1hVe" + }, + "source": [ + "# 你的成績" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GpJpZz3Wbm0X", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f1b08157-bec6-4c5a-8021-482f719b4ade" + }, + "source": [ + "print(f\"Your final reward is : %.2f\"%np.mean(test_total_reward))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Your final reward is : -30.32\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wUBtYXG2eaqf" + }, + "source": [ + "## 參考資料\n", + "\n", + "以下是一些有用的參考資料。\n", + "建議同學們實做前,可以先參考第一則連結的上課影片。\n", + "在影片的最後有提到兩個有用的 Tips,這對於本次作業的實做非常有幫助。\n", + "\n", + "- [DRL Lecture 1: Policy Gradient (Review)](https://youtu.be/z95ZYgPgXOY)\n", + "- [ML Lecture 23-3: Reinforcement Learning (including Q-learning) start at 30:00](https://youtu.be/2-JNBzCq77c?t=1800)\n", + "- [Lecture 7: Policy Gradient, David Silver](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/pg.pdf)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cGqP2EU1joWM" + }, + "source": [ + "" + ] + } + ] +} \ No newline at end of file diff --git a/范例/HW12/HW12_ZH.pdf b/范例/HW12/HW12_ZH.pdf new file mode 100644 index 0000000..e2c7481 Binary files /dev/null and b/范例/HW12/HW12_ZH.pdf differ