diff --git a/6_pytorch/0_basic/1-Tensor-and-Variable.ipynb b/6_pytorch/0_basic/1-Tensor-and-Variable.ipynb deleted file mode 100644 index fe0929a..0000000 --- a/6_pytorch/0_basic/1-Tensor-and-Variable.ipynb +++ /dev/null @@ -1,976 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tensor and Variable\n", - "\n", - "\n", - "张量(Tensor)是一种专门的数据结构,非常类似于数组和矩阵。在PyTorch中,我们使用张量来编码模型的输入和输出,以及模型的参数。\n", - "\n", - "张量类似于`numpy`的`ndarray`,不同之处在于张量可以在GPU或其他硬件加速器上运行。事实上,张量和NumPy数组通常可以共享相同的底层内存,从而消除了复制数据的需要(请参阅使用NumPy的桥接)。张量还针对自动微分进行了优化,在Autograd部分中看到更多关于这一点的内介绍。\n", - "\n", - "`variable`是一种可以不断变化的变量,符合反向传播,参数更新的属性。PyTorch的`variable`是一个存放会变化值的内存位置,里面的值会不停变化,像装糖果(糖果就是数据,即tensor)的盒子,糖果的数量不断变化。pytorch都是由tensor计算的,而tensor里面的参数是variable形式。\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Tensor基本用法\n", - "\n", - "PyTorch基础的数据是张量,PyTorch 的很多操作好 NumPy 都是类似的,但是因为其能够在 GPU 上运行,所以有着比 NumPy 快很多倍的速度。通过本次课程,能够学会如何像使用 NumPy 一样使用 PyTorch,了解到 PyTorch 中的基本元素 Tensor 和 Variable 及其操作方式。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.1 Tensor定义与生成" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# 创建一个 numpy ndarray\n", - "numpy_tensor = np.random.randn(10, 20)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "我们可以使用下面两种方式将numpy的ndarray转换到tensor上" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "pytorch_tensor1 = torch.Tensor(numpy_tensor)\n", - "pytorch_tensor2 = torch.from_numpy(numpy_tensor)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "使用以上两种方法进行转换的时候,会直接将 NumPy ndarray 的数据类型转换为对应的 PyTorch Tensor 数据类型" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "同时我们也可以使用下面的方法将 pytorch tensor 转换为 numpy ndarray" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# 如果 pytorch tensor 在 cpu 上\n", - "numpy_array = pytorch_tensor1.numpy()\n", - "\n", - "# 如果 pytorch tensor 在 gpu 上\n", - "numpy_array = pytorch_tensor1.cpu().numpy()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "需要注意 GPU 上的 Tensor 不能直接转换为 NumPy ndarray,需要使用`.cpu()`先将 GPU 上的 Tensor 转到 CPU 上" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.2 PyTorch Tensor 使用 GPU 加速\n", - "\n", - "我们可以使用以下两种方式将 Tensor 放到 GPU 上" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# 第一种方式是定义 cuda 数据类型\n", - "dtype = torch.cuda.FloatTensor # 定义默认 GPU 的 数据类型\n", - "gpu_tensor = torch.randn(10, 20).type(dtype)\n", - "\n", - "# 第二种方式更简单,推荐使用\n", - "gpu_tensor = torch.randn(10, 20).cuda(0) # 将 tensor 放到第一个 GPU 上\n", - "gpu_tensor = torch.randn(10, 20).cuda(0) # 将 tensor 放到第二个 GPU 上" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "使用第一种方式将 tensor 放到 GPU 上的时候会将数据类型转换成定义的类型,而是用第二种方式能够直接将 tensor 放到 GPU 上,类型跟之前保持一致\n", - "\n", - "推荐在定义 tensor 的时候就明确数据类型,然后直接使用第二种方法将 tensor 放到 GPU 上" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "而将 tensor 放回 CPU 的操作非常简单" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "cpu_tensor = gpu_tensor.cpu()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "我们也能够访问到 Tensor 的一些属性" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([10, 20])\n", - "torch.Size([10, 20])\n" - ] - } - ], - "source": [ - "# 可以通过下面两种方式得到 tensor 的大小\n", - "print(pytorch_tensor1.shape)\n", - "print(pytorch_tensor1.size())" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.FloatTensor\n", - "torch.cuda.FloatTensor\n" - ] - } - ], - "source": [ - "# 得到 tensor 的数据类型\n", - "print(pytorch_tensor1.type())\n", - "print(gpu_tensor.type())" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2\n" - ] - } - ], - "source": [ - "# 得到 tensor 的维度\n", - "print(pytorch_tensor1.dim())" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "200\n" - ] - } - ], - "source": [ - "# 得到 tensor 的所有元素个数\n", - "print(pytorch_tensor1.numel())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.3 小练习\n", - "\n", - "查阅以下[文档](http://pytorch.org/docs/0.3.0/tensors.html)了解 tensor 的数据类型,创建一个 float64、大小是 3 x 2、随机初始化的 tensor,将其转化为 numpy 的 ndarray,输出其数据类型\n", - "\n", - "参考输出: float64" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n" - ] - } - ], - "source": [ - "# 答案\n", - "x = torch.randn(3, 2)\n", - "x = x.type(torch.DoubleTensor)\n", - "x_array = x.numpy()\n", - "print(x_array.dtype)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Tensor的操作\n", - "Tensor 操作中的 API 和 NumPy 非常相似,如果你熟悉 NumPy 中的操作,那么 tensor 基本是一致的,下面我们来列举其中的一些操作" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 基本操作" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]])\n" - ] - } - ], - "source": [ - "x = torch.ones(3, 2)\n", - "print(x) # 这是一个float tensor" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.FloatTensor\n" - ] - } - ], - "source": [ - "print(x.type())" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[1, 1],\n", - " [1, 1],\n", - " [1, 1]])\n" - ] - } - ], - "source": [ - "# 将其转化为整形\n", - "x = x.long()\n", - "# x = x.type(torch.LongTensor)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]])\n" - ] - } - ], - "source": [ - "# 再将其转回 float\n", - "x = x.float()\n", - "# x = x.type(torch.FloatTensor)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-1.2200, 0.9769, -2.3477],\n", - " [ 1.0125, -1.3236, -0.2626],\n", - " [-0.3501, 0.5753, 1.5657],\n", - " [ 0.4823, -0.4008, -1.3442]])\n" - ] - } - ], - "source": [ - "x = torch.randn(4, 3)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# 沿着行取最大值\n", - "max_value, max_idx = torch.max(x, dim=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([0.9769, 1.0125, 1.5657, 0.4823])" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 每一行的最大值\n", - "max_value" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([1, 0, 2, 0])" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 每一行最大值的下标\n", - "max_idx" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([-2.5908, -0.5736, 1.7909, -1.2627])\n" - ] - } - ], - "source": [ - "# 沿着行对 x 求和\n", - "sum_x = torch.sum(x, dim=1)\n", - "print(sum_x)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([4, 3])\n", - "torch.Size([1, 4, 3])\n", - "tensor([[[-1.2200, 0.9769, -2.3477],\n", - " [ 1.0125, -1.3236, -0.2626],\n", - " [-0.3501, 0.5753, 1.5657],\n", - " [ 0.4823, -0.4008, -1.3442]]])\n" - ] - } - ], - "source": [ - "# 增加维度或者减少维度\n", - "print(x.shape)\n", - "x = x.unsqueeze(0) # 在第一维增加\n", - "print(x.shape)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 1, 4, 3])\n" - ] - } - ], - "source": [ - "x = x.unsqueeze(1) # 在第二维增加\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 4, 3])\n", - "tensor([[[-1.2200, 0.9769, -2.3477],\n", - " [ 1.0125, -1.3236, -0.2626],\n", - " [-0.3501, 0.5753, 1.5657],\n", - " [ 0.4823, -0.4008, -1.3442]]])\n" - ] - } - ], - "source": [ - "x = x.squeeze(0) # 减少第一维\n", - "print(x.shape)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([4, 3])\n" - ] - } - ], - "source": [ - "x = x.squeeze() # 将 tensor 中所有的一维全部都去掉\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([3, 4, 5])\n", - "torch.Size([4, 3, 5])\n", - "torch.Size([5, 3, 4])\n" - ] - } - ], - "source": [ - "x = torch.randn(3, 4, 5)\n", - "print(x.shape)\n", - "\n", - "# 使用permute和transpose进行维度交换\n", - "x = x.permute(1, 0, 2) # permute 可以重新排列 tensor 的维度\n", - "print(x.shape)\n", - "\n", - "x = x.transpose(0, 2) # transpose 交换 tensor 中的两个维度\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([3, 4, 5])\n", - "torch.Size([12, 5])\n", - "torch.Size([3, 20])\n" - ] - } - ], - "source": [ - "# 使用 view 对 tensor 进行 reshape\n", - "x = torch.randn(3, 4, 5)\n", - "print(x.shape)\n", - "\n", - "x = x.view(-1, 5) # -1 表示任意的大小,5 表示第二维变成 5\n", - "print(x.shape)\n", - "\n", - "x = x.view(3, 20) # 重新 reshape 成 (3, 20) 的大小\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[-3.1321, -0.9734, 0.5307, 0.4975],\n", - " [ 0.8537, 1.3424, 0.2630, -1.6658],\n", - " [-1.0088, -2.2100, -1.9233, -0.3059]])\n" - ] - } - ], - "source": [ - "x = torch.randn(3, 4)\n", - "y = torch.randn(3, 4)\n", - "\n", - "# 两个 tensor 求和\n", - "z = x + y\n", - "# z = torch.add(x, y)\n", - "print(z)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.2 `inplace`操作\n", - "另外,pytorch中大多数的操作都支持 `inplace` 操作,也就是可以直接对 tensor 进行操作而不需要另外开辟内存空间,方式非常简单,一般都是在操作的符号后面加`_`,比如" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([3, 3])\n", - "torch.Size([1, 3, 3])\n", - "torch.Size([3, 1, 3])\n" - ] - } - ], - "source": [ - "x = torch.ones(3, 3)\n", - "print(x.shape)\n", - "\n", - "# unsqueeze 进行 inplace\n", - "x.unsqueeze_(0)\n", - "print(x.shape)\n", - "\n", - "# transpose 进行 inplace\n", - "x.transpose_(1, 0)\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]])\n", - "tensor([[2., 2., 2.],\n", - " [2., 2., 2.],\n", - " [2., 2., 2.]])\n" - ] - } - ], - "source": [ - "x = torch.ones(3, 3)\n", - "y = torch.ones(3, 3)\n", - "print(x)\n", - "\n", - "# add 进行 inplace\n", - "x.add_(y)\n", - "print(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.3 **小练习**\n", - "\n", - "访问[文档](http://pytorch.org/docs/tensors.html)了解 tensor 更多的 api,实现下面的要求\n", - "\n", - "创建一个 float32、4 x 4 的全为1的矩阵,将矩阵正中间 2 x 2 的矩阵,全部修改成2\n", - "\n", - "参考输出\n", - "$$\n", - "\\left[\n", - "\\begin{matrix}\n", - "1 & 1 & 1 & 1 \\\\\n", - "1 & 2 & 2 & 1 \\\\\n", - "1 & 2 & 2 & 1 \\\\\n", - "1 & 1 & 1 & 1\n", - "\\end{matrix}\n", - "\\right] \\\\\n", - "[torch.FloatTensor\\ of\\ size\\ 4x4]\n", - "$$" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " 1 1 1 1\n", - " 1 2 2 1\n", - " 1 2 2 1\n", - " 1 1 1 1\n", - "[torch.FloatTensor of size 4x4]\n", - "\n" - ] - } - ], - "source": [ - "# 答案\n", - "x = torch.ones(4, 4).float()\n", - "x[1:3, 1:3] = 2\n", - "print(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Variable\n", - "tensor 是 PyTorch 中的基础数据类型,但是构建神经网络还远远不够,需要能够构建计算图的 tensor,这就是 Variable。Variable 是对 tensor 的封装,操作和 tensor 是一样的,但是每个 Variabel都有三个属性:\n", - "* Variable 中的 tensor本身`.data`,\n", - "* 对应 tensor 的梯度`.grad`\n", - "* Variable 是通过什么方式得到的`.grad_fn`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1 Variable的基本操作" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch.autograd import Variable" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "x_tensor = torch.randn(3, 4)\n", - "y_tensor = torch.randn(3, 4)\n", - "\n", - "# 将 tensor 变成 Variable\n", - "x = Variable(x_tensor, requires_grad=True) # 默认 Variable 是不需要求梯度的,所以我们用这个方式申明需要对其进行求梯度\n", - "y = Variable(y_tensor, requires_grad=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "z = torch.sum(x + y)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(-7.7018)\n", - "\n" - ] - } - ], - "source": [ - "print(z.data)\n", - "print(z.grad_fn)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "上面我们打出了 z 中的 tensor 数值,同时通过`grad_fn`知道了其是通过 Sum 这种方式得到的" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[1., 1., 1., 1.],\n", - " [1., 1., 1., 1.],\n", - " [1., 1., 1., 1.]])\n", - "tensor([[1., 1., 1., 1.],\n", - " [1., 1., 1., 1.],\n", - " [1., 1., 1., 1.]])\n" - ] - } - ], - "source": [ - "# 求 x 和 y 的梯度\n", - "z.backward()\n", - "\n", - "print(x.grad)\n", - "print(y.grad)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "通过`.grad`我们得到了 x 和 y 的梯度,这里我们使用了 PyTorch 提供的自动求导机制,非常方便,下一小节会具体讲自动求导。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.2 **小练习**\n", - "\n", - "尝试构建一个函数 $y = x^2 $,然后求 x=2 的导数。\n", - "\n", - "参考输出:4" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "提示:\n", - "\n", - "$y = x^2$的图像如下" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnB0lEQVR4nO3dd3hUVf7H8feZyaRDAkkIhCSEEFpAekcUBbvoYqPYcFXsZd2iq/tTd11dy9pdC9ZVKVbEiqKgINICRFoChIQ0IIUQkpBCMnN+fyS6igSGkMm5M/N9PU8eyWSY+VwDn1zOPfccpbVGCCGEddlMBxBCCHFkUtRCCGFxUtRCCGFxUtRCCGFxUtRCCGFxAZ540ejoaJ2UlOSJlxZCCJ+0du3aUq11zOG+5pGiTkpKIi0tzRMvLYQQPkkpldvc12ToQwghLE6KWgghLE6KWgghLE6KWgghLE6KWgghLE6KWgghLE6KWgghLM4yRV1b7+Tlpdn8sKPUdBQhhDhmSzKLeX15DgcbXK3+2pYp6gCb4uVl2by6LMd0FCGEOGYvfLeD//6wE4ddtfprW6eo7TYuGhrPkq3F7NlfazqOEEK4LbukitU5ZVwyPAGlfLioAS4ZloBLw/tr801HEUIIt72Tlo/dprhoSLxHXt9SRZ0UHcbo5CjeScvH5ZItwoQQ1lfvdPHB2gJO7dOJTu2DPfIelipqgKkjEsgvq2FF9l7TUYQQ4qi+ySimtOogU4cneOw9LFfUZ/TrTESIg7mr80xHEUKIo5q3Jo/O7YM5uddhVyhtFZYr6mCHncmDu/LV5iLKDhw0HUcIIZq1q7yG77aVcPGweALsnqtTyxU1wJThCRx0upi/vtB0FCGEaNZ7aQVo3TgRwpMsWdR9u7RnYEIk76zJQ2u5qCiEsB6nS/NuWj4npkST0DHUo+9lyaIGmDo8gW1FVazPLzcdRQghfmN5VimF5TVM8eBFxJ9YtqgnDYwjNNDOPLmoKISwoHlr8ugQ6uD0frEefy/LFnV4UACTBsTxyY+7qaytNx1HCCF+VlpVx6ItRVwwJJ6gALvH38+yRQ0wZUQCNfVOPvlxt+koQgjxs/nrCql36jYZ9gCLF/XghEh6x7Zj3hoZ/hBCWIPWmrmr8xiSGEmv2HZt8p6WLmqlFNNHJrKhYD8bCspNxxFCCFZk7yW79ACXjuzWZu9p6aIGmDykKyEOO3NWyVm1EMK82avyiAhxcM6ALm32npYv6vbBDs4bGMeC9F1UyEVFIYRBJZV1fLlpDxcNjSfY4fmLiD+xfFEDXDoqkZp6JwvkTkUhhEHvrc2nwaWZPjKxTd/XK4p6QHwk/bu2Z/YquVNRCGGGy6WZsyqPUckd6RET3qbv7VZRK6X+oJTarJTapJSaq5TyzKKrR3DpyG5k7qlkXd6+tn5rIYRg6fYSCvbVtOlFxJ8ctaiVUl2BW4FhWuv+gB2Y6ulghzpvYBzhQQHMXikXFYUQbW/2qjyiwgI5o1/nNn9vd4c+AoAQpVQAEArs8lykwwsLCmDy4K58unE35dWy/KkQou3s3l/D4sxiLhmeQGBA248YH/UdtdaFwL+BPGA3sF9r/dWhz1NKzVRKpSml0kpKSlo/KTB9ZCIHG1y8v7bAI68vhBCH886afJwuzbThbXsR8SfuDH10AM4HugNxQJhS6rJDn6e1nqW1Hqa1HhYT45mdDvp2ac+QxEjmyEVFIUQbaXC6mLc6n5N6xZAY5dnlTJvjzjn8RCBHa12ita4HPgTGeDZW8y4d2Y3s0gOyp6IQok0szixmT0Utl7bxlLxfcqeo84BRSqlQpZQCJgAZno3VvHMGdCEixCEXFYUQbWL2qjxi2wcxoU8nYxncGaNeBbwPrAM2Nv2eWR7O1axgh52Lh8bz5eY9FFXUmoohhPADO0sP8N22EqaNSPTonohH49Y7a63v01r30Vr311pfrrWu83SwI7lsVDecWsv6H0IIj3prZS4BNsX0EeaGPcBL7kw8VFJ0GCf3imHO6jwONrhMxxFC+KDqgw28l5bPmf0706l9m9/j9yteWdQAV4zu1rhAyuY9pqMIIXxQ40JwDVwxOsl0FO8t6pN7dSKxYyhvrcg1HUUI4WO01ry5Ipc+ndsxPKmD6TjeW9R2m+KyUYms3llGxu4K03GEED5kbe4+MnZXcMXoJBonu5nltUUNcMmwBIICbLwpZ9VCiFb03xW5tAsO4HeD40xHAby8qCNDAzl/UBwfrS9kf41sKiCEOH7FFbV8sXE3Fw9NIDQwwHQcwMuLGuCK0UnU1Dtl/Q8hRKuYu7pxc4DLR7f9cqbN8fqi7t81giGJkby9MheXS9b/EEK0XL3TxZzVuZzUK4bu0WGm4/zM64sa4MoxSeSUHmBZVqnpKEIIL/bV5iKKKuq40kJn0+AjRX1m/85Ehwfy5g87TUcRQnixN1fsJL5DCON7m1vX43B8oqiDAuxMH5HI4q3F7Cw9YDqOEMILbd61n1U5ZVwxuht2m/kpeb/kE0UNjet/BNgUb8hZtRCiBV5fvpPQQDtThpld1+NwfKaoO7UP5twBcbyXlk9FrUzVE0K4r6Syjo/Td3HhkHgiQh2m4/yGzxQ1wO/HdufAQSfvpclUPSGE++asyuOg08WMsUmmoxyWTxX1CfERDE/qwBs/5OCUqXpCCDfUNTh5a2Uup/SOoUdMuOk4h+VTRQ1w1dju5JfV8HVGkekoQggv8OmPuymtquOqsd1NR2mWzxX16amxdI0M4fXlOaajCCEsTmvNa8tzSOkUzrie0abjNMvnijrAbuOK0d1YmV3G5l37TccRQljYmp372LyrgqvGWmOVvOb4XFEDTB2eSIjDzhvLd5qOIoSwsNeX5xAR4uCCwfGmoxyRTxZ1RKiDi4bGsyB9F6VVRrd3FEJYVH5ZNV9u3sP0kYmEBNpNxzkinyxqgBljkzjodDF7pWyAK4T4rTdX7EQpxeWjrLWux+H4bFH3iAlnfO8Y3lqZS22903QcIYSFVNU1MG9N48a1cZEhpuMclc8WNcC145IprapjQXqh6ShCCAuZtzqPytoGZo5LNh3FLT5d1GN6RJHapT0vL8uRtaqFEAA0OF28vnwnI7p3ZGBCpOk4bvHpolZKMfOkZLKKq/huW4npOEIIC/h80x4Ky2u85mwafLyoAc4Z0IUuEcHMWpptOooQwjCtNbOW7iA5JoxT+1hrzekj8fmidtht/H5sd1Zk72VjgdwAI4Q/W5ldxqbCCq4dl4zNYmtOH4nPFzXA1BEJtAsK4OVlclYthD97eVk20eGBTB7c1XSUY+IXRd0u2MG0kYl8tnE3BfuqTccRQhiwvaiSxZnFXDE6iWCHtW9wOZRfFDXAjDFJKBp3cRBC+J9XluUQ7LBxmRfc4HIovynquMgQJg2MY97qPPbXyA4wQviT4spa5q8v5OKhCXQMCzQd55j5TVEDXDOucQeYuavltnIh/MmbP+RS73Jx9YnWXXP6SPyqqPvFRTA2JYrXl+dQ1yC3lQvhDw7UNfD2qlxOT40lKTrMdJwW8auiBrj+5B4UVdTx0Xq5rVwIfzB3dR7l1fVcd3IP01FazO+K+sSUaPp3bc+L32XLvopC+Li6BievLMthVHJHhiR2MB2nxdwqaqVUpFLqfaVUplIqQyk12tPBPEUpxY3jU8gpPcDCTXtMxxFCeNBH6wvZU1HLjeNTTEc5Lu6eUT8NLNRa9wEGAhmei+R5Z/TrTHJ0GC98l4XWclYthC9yujQvfZdN/67tLb0fojuOWtRKqQjgJOBVAK31Qa11uYdzeZTdprju5GQ2FVawbHup6ThCCA/4cvMesksPcMPJKZbeD9Ed7pxRdwdKgNeVUuuVUq8opX5z6VQpNVMplaaUSispsf5KdZMHx9O5fTDPf5tlOooQopVprXn+2yySo8M4s39n03GOmztFHQAMAV7QWg8GDgB3HfokrfUsrfUwrfWwmJiYVo7Z+gIDbFwzrjsrs8tYl7fPdBwhRCv6PquUTYUVXHdyMnYvWnypOe4UdQFQoLVe1fT5+zQWt9ebNiKRyFAHL3y7w3QUIUQren7JDmLbB/E7L1t8qTlHLWqt9R4gXynVu+mhCcAWj6ZqI2FBAVw5OolFW4rYVlRpOo4QohWsz9vHiuy9XDsumaAA71p8qTnuzvq4BZitlNoADAIe8liiNjZjTBIhDjsvylm1ED7h+W93EBHiYNqIRNNRWo1bRa21Tm8afx6gtf6d1tpnBnU7hAUybUQiC37cRd5eWQJVCG+2dU8li7YUceWYJMKCAkzHaTV+d2fi4fx0weGF72QGiBDe7NnF2wkLtPP7sUmmo7QqKWogtn0wU4cn8P7aAtlYQAgvlVVcyWcbd3PlmCQiQ71vKdMjkaJucn3Tgi0vfidj1UJ4o+cWZxHisHONF+0u7i4p6iZxkSFcNDSBd9cUsGd/rek4QohjkFN6gI9/3MVlo7p55cYARyNF/Qs3ju+BS2s5qxbCy/xnSRYOu41rffBsGqSofyWhYygXDOnK3NV5FFfIWbUQ3iBvbzXz1xdy6chuxLQLMh3HI6SoD3HTKSk0uDSzlmabjiKEcMPz32b9vNCar5KiPkS3qDDOHxTH26tyKa2qMx1HCHEEBfuqeX9tAdOGJxDbPth0HI+Roj6Mm05Joa7BxcvL5KxaCCt74dsdKIVXb7PlDinqw+gRE86kAXG8tSKXvXJWLYQl7Sqv4b20Ai4elkBcZIjpOB4lRd2MWyekUFvv5CUZqxbCkp5dnIVGc+N43z6bBinqZqV0asfvBnXlzRU7Ka6UGSBCWEne3mreS8tn2ohE4juEmo7jcVLUR3DrhJ7UOzXPL5F51UJYyTOLt2O3KW46xbs3rXWXFPURJEWHcdGQeOasymNXeY3pOEIIYEdJFR+uK+CyUd18eqbHL0lRH8UtE1LQaJ5bIivrCWEFT3+9naAAOzf4wdj0T6SojyK+QyhThyfy7pp8Wa9aCMO27qnkkw27mDE2iehw37wL8XCkqN1w0ykp2GyKZxZvNx1FCL/25KJthAUGMNNH1/RojhS1GzpHBHP5qG58uK6A7JIq03GE8EubCvezcPMerj6xOx18cIW8I5GidtMN43sQFGDnqa/lrFoIE55YtI2IEAdXj+tuOkqbk6J2U3R4EDPGJvHJhl1k7K4wHUcIv7I2dx+LM4uZeVIy7YMdpuO0OSnqY3DdScm0CwrgsS+3mo4ihN/QWvPIF5lEhwdxlY/theguKepjEBkayA3jU1icWcyq7L2m4wjhF5ZsLWb1zjJum9iT0EDf2Vn8WEhRH6MZY5KIbR/Ewwsz0VqbjiOET3O6NI98sZWkqFCmDk8wHccYKepjFBJo5w8Te7E+r5yvthSZjiOET/tofSFbiyr50xm9cdj9t67898iPw0VD4+kRE8ajCzNpcLpMxxHCJ9XWO3li0TZO6BrB2f27mI5jlBR1CwTYbfz5jD7sKDnAB+sKTMcRwie9vTKXwvIa7jqrDzabMh3HKCnqFjqjXyyDEyN5ctF2auudpuMI4VMqaut5bkkW43pGMzYl2nQc46SoW0gpxZ1n9mFPRS1v/LDTdBwhfMqs77Ipr67nzjP7mI5iCVLUx2FUchSn9I7h+SVZlFcfNB1HCJ9QXFHLq9/nMGlgHP27RpiOYwlS1MfpzrP6UFXXwDPfyDKoQrSGf3+1lQaXiz+d3st0FMuQoj5OfTq355JhCby5Yqcs2CTEcdq8az/vrS3gytFJdIsKMx3HMqSoW8Edp/ciKMDGw19kmo4ihNfSWvPgZxlEhji4ZUJP03EsRYq6FXRqF8yNp6Tw1ZYiVuyQW8uFaIlvMor5Ycdebp/Yi4gQ/1t46UikqFvJ1Sd2Jy4imH9+tgWXS24tF+JY1DtdPPR5BskxYUwfmWg6juVIUbeSYIedO8/qw+ZdFXy4vtB0HCG8yuyVuWSXHuCes/v69a3izXH7/4hSyq6UWq+U+tSTgbzZeQPjGJQQyWNfZlJ9sMF0HCGsbfZsSEpC22ycfs4o/lyaxql9OplOZUnH8qPrNiDDU0F8gVKK/zu3L0UVdcxamm06jhDWNXs2zJwJubkorYnbX8wNcx5BzZljOpkluVXUSql44BzgFc/G8X5Du3XknAFdeOm7bHbvrzEdRwhruuceqK7+1UO2mprGx8VvuHtG/RTwF6DZpeKUUjOVUmlKqbSSkpLWyOa17jqzDy6t+dfnMl1PiMPKyzu2x/3cUYtaKXUuUKy1Xnuk52mtZ2mth2mth8XExLRaQG+U0DGU607uwcc/7mKl7AQjxG8lNjOzo7nH/Zw7Z9RjgfOUUjuBecCpSqm3PZrKB9xwcg+6RoZw/8ebZc1qIQ5x8B8PUOsI+vWDoaHw4INmAlncUYtaa/1XrXW81joJmAos1lpf5vFkXi4k0M7/nZtK5p5K3l6ZazqOEJYyK34UfznjZmrj4kEp6NYNZs2CSy81Hc2SZMKiB53RL5ZxPaN5fNE2SqvqTMcRwhIKy2t4bkkW9VOmEVyYDy4X7NwpJX0Ex1TUWutvtdbneiqMr1FKcd+kftQcdPLYwq2m4whhCQ991jjL955z+hpO4j3kjNrDUjqFc/WJ3XknLZ/0/HLTcYQwanlWKZ9t3M1N41OI7xBqOo7XkKJuA7dM6EmndkHct2CTrAMi/Fa908V9H28msWMo156UbDqOV5GibgPhQQHcfXZffizYz7w1+abjCGHE68tzyCqu4t5zUwl22E3H8SpS1G3k/EFxjEruyMNfZFBSKRcWhX8p2FfNk4u2M7FvJyb0lfU8jpUUdRtRSvHg5BOorXfxz8+2mI4jRJvRWnPvgs0oBX8/vz9KKdORvI4UdRvqERPODeN7sCB9F0u3+fdt9sJ/LNy0h8WZxdxxWi+6RoaYjuOVpKjb2A3je5AcHcbfPtpEbb3TdBwhPKqytp77P9lMapf2zBiTZDqO15KibmPBDjv/nNyfvLJqnlssO5cL3/b4V9sorqzjXxecQIBsCNBi8n/OgDE9orlgSFdeWrqDbUWVpuMI4RE/5pfz3xU7uWJUNwYmRJqO49WkqA255+y+hAUFcM/8jTK3WvicBqeLv364kU7tgvjjGb1Nx/F6UtSGRIUHcffZfVmzcx9z18gavMK3vLY8hy27K7h/Uj/aB8uO4sdLitqgi4fGM6ZHFP/6PJPCctkNRviG7JIqHv9qGxP7xnJm/86m4/gEKWqDlFI8fMEAnC7NXz/ciNYyBCK8m9Ol+cv7GwgKsPHQZJkz3VqkqA1LjArlzjN7s3RbCe+vLTAdR4jj8uaKnaTl7uPeSf3o1D7YdByfIUVtAVeMTmJEUkce+HQLRRW1puMI0SK5ew/w6MKtjO8dw4VDupqO41OkqC3AZlM8ctEA6hpc3DNfhkCE93G5NHd+sIEAm+JfF5wgQx6tTIraIrpHh/HnM3rzdUYxC9J3mY4jxDGZvTqPldll3HNOX7pEyG3irU2K2kKuGtudIYmR3P/JZoorZQhEeIeCfdU8/HkG43pGM2V4guk4PkmK2kLsNsWjFw2k+qCTu2UWiPACLpfmz+9tAJAhDw+SoraYlE7h/KVpCEQ2GRBW9+r3OazI3su9k1Jlay0PkqK2oN+P7c7YlCge+HQLO0sPmI4jxGFl7K7gsS+3cnpqLJcMkyEPT5KitiCbTfHviwcSYFPc/k46DU6X6UhC/EptvZM/vJNO+xCHDHm0ASlqi+oSEcKDk08gPb+c/yzZYTqOEL/y+FdbydxTyWMXDSAqPMh0HJ8nRW1hkwbG8btBcTyzeDvr8/aZjiMEAD9klfLyshwuG5XIKX1k/8O2IEVtcX8/vz+x7YK4490fqT7YYDqO8HP7q+v543s/khwdxj1np5qO4zekqC0uIsTB45cMYufeA/zjE9kUV5ijtebujzZSUlnHk1MGERJoNx3Jb0hRe4HRPaK44eQezFuTz4L0QtNxhJ+aszqPzzbs5o7Te8mOLW1MitpL3HFaL4Z168DdH24ku6TKdBzhZ7bsquDvn2zhpF4xXH9SD9Nx/I4UtZcIsNt4ZtpgHAE2bpqzXnYwF22mqq6Bm+esIzLEwROXDMRmk6l4bU2K2ovERYbwxCUDydhdwT8/k/Fq4Xlaa/42fyM79x7gmWmDiZapeEZIUXuZU/vEMvOkZN5e2TheKIQnvZdWwEfpu7h9Yi9GJUeZjuO3pKi90J/P6M2ghEju+mADuXvlFnPhGduKKrn3402M6RHFTaekmI7j16SovZDDbuO56YNRCm54ex01B2W8WrSuytp6rn97LeFBATw1dRB2GZc2SoraS8V3COWpqYPI2FPBXz/cIEuiilbjcmnuePdHcvdW89z0IXRqJ3sfmiZF7cVO7RPLHRN78VH6Ll5fvtN0HOEjnluSxaItRfztnL4yLm0RRy1qpVSCUmqJUmqLUmqzUuq2tggm3HPTKSmcnhrLg59nsGLHXtNxhJf7JqOIJ7/exgWDuzJjTJLpOKKJO2fUDcAftdapwCjgJqWU3ORvETab4vFLBpIUFcrNc9ZRWF5jOpLwUtklVdw+L53ULu15SJYutZSjFrXWerfWel3TryuBDED2greQdsEOZl0xjLoGF9e/tVZuhhHHrKqugeveWkuAXfHS5UMJdsg6HlZyTGPUSqkkYDCw6jBfm6mUSlNKpZWUlLRSPOGuHjHhPDllEBsL98t+i+KYuFyaP76bzo6SKv4zfYhsqWVBbhe1Uioc+AC4XWtdcejXtdaztNbDtNbDYmJiWjOjcNNpqbHccVovPlxfyHOLs0zHEV7ikYWZfLm5iL+dk8qYlGjTccRhBLjzJKWUg8aSnq21/tCzkcTxuOXUFHaWHuDxRdtIjArl/EEySiWaN3d1Hi8tzebyUd24amyS6TiiGe7M+lDAq0CG1voJz0cSx0Mpxb8uPIER3Tvy5/c3sDa3zHQkYVHLtpfwt482Mb53DPdNSpWLhxbmztDHWOBy4FSlVHrTx9keziWOQ1CAnZcuG0rXyBCufXOt3GYufmNbUSU3vr2Onp3CeXbaYALsckuFlbkz6+N7rbXSWg/QWg9q+vi8LcKJlusQFshrM4bj0pqr3ljD/up605GERZRU1nHV62sIDrTz6ozhtAt2mI4kjkJ+jPqw7tFhvHTZUPLLqpn5VppM2xMcqGvgmjfT2HugjlevHEbXyBDTkYQbpKh93MjkKP598UBW5ZRx69z1NDhdpiMJQ+oanFz/9lo2Fe7n2WlDGBAfaTqScJMUtR84f1BX7p+UyldbivirzLH2S06X5o53fmTZ9lIeuXAAp6XGmo4kjoFb0/OE95sxtjv7qut5+pvtRIY6uPvsvnKV309orfnbR5v4bONu/nZOXy4aGm86kjhGUtR+5PaJPSmvPsjLy3LoEBbIjeNlMXh/8NiXW5m7Oo+bTunBNeOSTccRLSBF7UeUUtw3qR/lNfU8unArkSGBTB+ZaDqW8KCXl2bz/Lc7mD4ykT+d3tt0HNFCUtR+xmZT/PvigVTWNnDPRxsJsCkuGZ5gOpbwgNe+z+HBzzM4Z0AXHji/vwx1eTG5mOiHHHYbz186hJN6xvCXDzbwzpo805FEK3v1+xz+8ekWzurfmaemyFZa3k6K2k8FO+y8dPlQxveO4c4PNjJvtZS1r3hlWTYPfLqFs0/ozDPTBuOQuw69nnwH/Viww86Llw3llN4x3PXhRuaskrL2dq8sy+afn2VwzgldeHqqlLSvkO+inwt22Hnx8sayvnv+RmavyjUdSbTQy0v/V9JPTR0kJe1D5DspCApoLOtT+3TinvmbeP7bLLkpxotorXnsy8yfLxw+LSXtc+S7KYCmsr5sKOcNjOPRhVt54NMMXC4pa6trcLq464ON/GfJDqaNSOCZqbISni+S6XniZ4EBNp6aMoio8EBeW57D3gN1PHbRQAID5C++FdXWO7ll7noWbSni1lNT+MNpvWQKno+Soha/YrMp7j03lZh2QTy6cCv7qut54dIhhAXJHxUr2V9Tz7X/TWNNbhl/P68fV45JMh1JeJCcKonfUEpx4/gUHr1wAN9vL2H6yysprqg1HUs0KdhXzZSXVrA+fx/PThssJe0HpKhFsy4ZnsCsy4exvbiK855bzoaCctOR/N6anWWc/9xyCstreH3GCM4dEGc6kmgDUtTiiCamxvLBDWOw2xQXv7iCBemFpiP5rXmr85j+8koiQhx8dNNYTuwpO4b7CylqcVR9u7Tn45vHMjA+ktvmpfPYl5kyI6QNNThd3P/xZu76cCOjkqOYf+NYesSEm44l2pAUtXBLVHgQb18zkmkjEvjPkh3MfGst+2tkH0ZP21tVx1VvrOGNH3Zy9YndeX3GcCJCZY9DfyNFLdwWGGDjockncP+kVL7dWszZTy9jXd4+07F81g87Sjnr6WWsyinj0QsH8H/npsocaT8l33VxTJRSzBjbnfeuH41ScPGLK3jh2x0yFNKKGpwunli0jUtfWUV4cADzbxwjS9H6OSlq0SKDEzvw2a3jOLNfZx5ZmMmVr6+mpLLOdCyvt3t/DdNfXsUz32znwiHxfHLzifSLizAdSxgmRS1aLCLEwXPTB/PQ5BNYnVPGWU8vY+Gm3aZjeSWtNQvSCznr6WVs2rWfJ6cM5N8XD5QbjQQgRS2Ok1KK6SMT+fjmE+nULojr317HjbPXUlwpN8i4a1d5DVf/N43b5qWTFBXGp7ecyOTBsgGt+B/liVXShg0bptPS0lr9dYW11TtdzFqazdPfbCfEYef/zk3lwiFdZf2JZrhcmjmr83j4i0ycLs2fzujNjDFJshuLn1JKrdVaDzvs16SoRWvLKq7izg82sDZ3Hyf1iuH+Sakky7zfX9m6p5J7F2xiVU4ZY1Oi+NfkASRGhZqOJQySohZtzuXSvLUyl0cXZlLX4OLy0d24bUJPIkMDTUczqqSyjie/3sa81XmEBwVwzzl9uWRYgvyrQ0hRC3NKKut4YtE23lmTR7tgB7dO6Mnlo7r53dKptfVOXluew/NLdlBb7+SyUY0/uDqE+fcPLvE/UtTCuMw9FTz4WQbLtpfSPTqMW05NYdLAOJ/fiaSuwcn8dYU8uziLwvIaJvaN5a9n95FbwMVvSFELS9Ba8+22Eh75IpPMPZXEdwjhupN7cPHQeIIddtPxWlX1wQbmrs7n5aXZ7Kmo5YSuEdx1Vh/GpshCSuLwpKiFpWitWZxZzHNLslifV050eBDXjOvO1OEJXj+Gvbeqjjmr8nhteQ77qusZldyRG8enMK5ntIxDiyOSohaWpLVmZXYZz3+bxbLtpQQG2Dirf2emDEtgVHIUNi+ZpuZ0ab7PKuWdNXks2lJEvVMzoU8nbjylB0O7dTQdT3iJIxW13PYkjFFKMbpHFKN7RLFlVwXvrMlj/vpCFqTvIrFjKJcMi2fSwDi6RYWZjnpYO0qq+Dh9F++vLaCwvIYOoQ6uGJ3E1OEJ9IxtZzqe8CFyRi0spbbeycJNe5i3Jo+V2WUA9OwUzoS+sZyW2olBCR2M3RDS4HSxNncfX2cU8XVGMTmlBwAY1zOaKcMTOC01lqAA3xprF23nuIc+lFJnAk8DduAVrfXDR3q+FLVoDfll1SzaUsTXGUWszimjwaWJCgtkZHJHBiVEMjixA/3jIggJ9Ew5HqhrYGPhftbnlZOev49VOWWUV9fjsCtGJUdxWmosE/rG0jUyxCPvL/zLcRW1UsoObANOAwqANcA0rfWW5n6PFLVobftr6vluWwmLM4pYm7eP/LIaAOw2RZ/O7egV246EjqEkdAghoWMoiR1D6RgWSFCArdmLeFpr6hpclFbVkV9WQ35ZNfn7qskvqyZzTyXbiir5afXWblGhDO3WgYl9YxnXM5p2wbJ4v2hdxztGPQLI0lpnN73YPOB8oNmiFqK1RYQ4OG9gHOcNbNzMtbSqjvS8ctLzGz9W55TxUXohh5532BSEBgYQEmgnNNCO1lB90EnNwQZq6p0cuoy2TUGXiBCSY8I4PTWWwYkdGJgQSUe5MUUY5E5RdwXyf/F5ATDy0CcppWYCMwESExNbJZwQzYkOD2JiaiwTU2N/fuxgg4td5TXk76smr6ya8up6ag46qT7opPpgA9UHnSgFoYF2QhwBjf8NtNMxLJCEDo1n4V0ig33+JhzhfVpt1ofWehYwCxqHPlrrdYVwV2CAjaToMJKirTlLRIiWcufUoRD45T5A8U2PCSGEaAPuFPUaoKdSqrtSKhCYCnzs2VhCCCF+ctShD611g1LqZuBLGqfnvaa13uzxZEIIIQA3x6i11p8Dn3s4ixBCiMOQy9tCCGFxUtRCCGFxUtRCCGFxUtRCCGFxHlk9TylVAuS28LdHA6WtGMckXzkWXzkOkGOxIl85Dji+Y+mmtY453Bc8UtTHQymV1tzCJN7GV47FV44D5FisyFeOAzx3LDL0IYQQFidFLYQQFmfFop5lOkAr8pVj8ZXjADkWK/KV4wAPHYvlxqiFEEL8mhXPqIUQQvyCFLUQQlicJYtaKfWAUmqDUipdKfWVUirOdKaWUEo9ppTKbDqW+UqpSNOZWkopdbFSarNSyqWU8rqpVEqpM5VSW5VSWUqpu0znOR5KqdeUUsVKqU2msxwPpVSCUmqJUmpL05+t20xnaimlVLBSarVS6semY/l7q76+FceolVLttdYVTb++FUjVWl9vONYxU0qdDixuWir2EQCt9Z2GY7WIUqov4AJeAv6ktfaa3YtbskGzlSmlTgKqgDe11v1N52kppVQXoIvWep1Sqh2wFvidN35fVOMOymFa6yqllAP4HrhNa72yNV7fkmfUP5V0kzDAej9N3KC1/kpr3dD06Uoad8fxSlrrDK31VtM5WujnDZq11geBnzZo9kpa66VAmekcx0trvVtrva7p15VABo17tHod3aiq6VNH00er9ZYlixpAKfWgUiofuBS413SeVvB74AvTIfzU4TZo9spC8FVKqSRgMLDKcJQWU0rZlVLpQDGwSGvdasdirKiVUl8rpTYd5uN8AK31PVrrBGA2cLOpnEdztONoes49QAONx2JZ7hyLEK1NKRUOfADcfsi/pr2K1tqptR5E47+cRyilWm1YqtV2IT9WWuuJbj51No27y9znwTgtdrTjUErNAM4FJmgrXhD4hWP4nngb2aDZoprGcz8AZmutPzSdpzVorcuVUkuAM4FWueBryaEPpVTPX3x6PpBpKsvxUEqdCfwFOE9rXW06jx+TDZotqOkC3KtAhtb6CdN5jodSKuanWV1KqRAaL1y3Wm9ZddbHB0BvGmcZ5ALXa6297gxIKZUFBAF7mx5a6Y2zVwCUUpOBZ4EYoBxI11qfYTTUMVBKnQ08xf82aH7QbKKWU0rNBcbTuKRmEXCf1vpVo6FaQCl1IrAM2Ejj33WAu5v2aPUqSqkBwH9p/PNlA97VWv+j1V7fikUthBDifyw59CGEEOJ/pKiFEMLipKiFEMLipKiFEMLipKiFEMLipKiFEMLipKiFEMLi/h95yyGcg55E7QAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "x = np.arange(-3, 3.01, 0.1)\n", - "y = x ** 2\n", - "plt.plot(x, y)\n", - "plt.plot(2, 4, 'ro')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([4.])\n" - ] - } - ], - "source": [ - "import torch\n", - "from torch.autograd import Variable\n", - "\n", - "# 答案\n", - "x = Variable(torch.FloatTensor([2]), requires_grad=True)\n", - "y = x ** 2\n", - "y.backward()\n", - "print(x.grad)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "下一次课程我们将会从导数展开,了解 PyTorch 的自动求导机制" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## References\n", - "* http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html\n", - "* http://cs231n.github.io/python-numpy-tutorial/" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/0_basic/imgs/autograd_Variable.png b/6_pytorch/0_basic/imgs/autograd_Variable.png deleted file mode 100644 index 6576cc8..0000000 Binary files a/6_pytorch/0_basic/imgs/autograd_Variable.png and /dev/null differ diff --git a/6_pytorch/0_basic/imgs/autograd_Variable.svg b/6_pytorch/0_basic/imgs/autograd_Variable.svg deleted file mode 100644 index 6164d89..0000000 --- a/6_pytorch/0_basic/imgs/autograd_Variable.svg +++ /dev/null @@ -1,2 +0,0 @@ - -
data
[Not supported by viewer]
grad
[Not supported by viewer]
grad_fn
[Not supported by viewer]
autograd.Variable
[Not supported by viewer]
\ No newline at end of file diff --git a/6_pytorch/0_basic/ref_Autograd.ipynb b/6_pytorch/0_basic/ref_Autograd.ipynb deleted file mode 100644 index 703dd93..0000000 --- a/6_pytorch/0_basic/ref_Autograd.ipynb +++ /dev/null @@ -1,1554 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3.2 autograd\n", - "\n", - "用Tensor训练网络很方便,但从上一小节最后的线性回归例子来看,反向传播过程需要手动实现。这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查。torch.autograd就是为方便用户使用,而专门开发的一套自动求导引擎,它能够根据输入和前向传播过程自动构建计算图,并执行反向传播。\n", - "\n", - "计算图(Computation Graph)是现代深度学习框架如PyTorch和TensorFlow等的核心,其为高效自动求导算法——反向传播(Back Propogation)提供了理论支持,了解计算图在实际写程序过程中会有极大的帮助。本节将涉及一些基础的计算图知识,但并不要求读者事先对此有深入的了解。关于计算图的基础知识推荐阅读Christopher Olah的文章[^1]。\n", - "\n", - "[^1]: http://colah.github.io/posts/2015-08-Backprop/\n", - "\n", - "\n", - "### 3.2.1 Variable\n", - "PyTorch在autograd模块中实现了计算图的相关功能,autograd中的核心数据结构是Variable。Variable封装了tensor,并记录对tensor的操作记录用来构建计算图。Variable的数据结构如图3-2所示,主要包含三个属性:\n", - "\n", - "- `data`:保存variable所包含的tensor\n", - "- `grad`:保存`data`对应的梯度,`grad`也是variable,而不是tensor,它与`data`形状一致。 \n", - "- `grad_fn`: 指向一个`Function`,记录tensor的操作历史,即它是什么操作的输出,用来构建计算图。如果某一个变量是由用户创建,则它为叶子节点,对应的grad_fn等于None。\n", - "\n", - "\n", - "![图3-2:Variable数据结构](imgs/autograd_Variable.png)\n", - "\n", - "Variable的构造函数需要传入tensor,同时有两个可选参数:\n", - "- `requires_grad (bool)`:是否需要对该variable进行求导\n", - "- `volatile (bool)`:意为”挥发“,设置为True,则构建在该variable之上的图都不会求导,专为推理阶段设计\n", - "\n", - "Variable提供了大部分tensor支持的函数,但其不支持部分`inplace`函数,因这些函数会修改tensor自身,而在反向传播中,variable需要缓存原来的tensor来计算反向传播梯度。如果想要计算各个Variable的梯度,只需调用根节点variable的`backward`方法,autograd会自动沿着计算图反向传播,计算每一个叶子节点的梯度。\n", - "\n", - "`variable.backward(grad_variables=None, retain_graph=None, create_graph=None)`主要有如下参数:\n", - "\n", - "- grad_variables:形状与variable一致,对于`y.backward()`,grad_variables相当于链式法则${dz \\over dx}={dz \\over dy} \\times {dy \\over dx}$中的$\\textbf {dz} \\over \\textbf {dy}$。grad_variables也可以是tensor或序列。\n", - "- retain_graph:反向传播需要缓存一些中间结果,反向传播之后,这些缓存就被清空,可通过指定这个参数不清空缓存,用来多次反向传播。\n", - "- create_graph:对反向传播过程再次构建计算图,可通过`backward of backward`实现求高阶导数。\n", - "\n", - "上述描述可能比较抽象,如果没有看懂,不用着急,会在本节后半部分详细介绍,下面先看几个例子。" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import print_function\n", - "import torch as t\n", - "from torch.autograd import Variable as V" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 从tensor中创建variable,指定需要求导\n", - "a = V(t.ones(3,4), requires_grad = True) \n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 0 0 0 0\n", - " 0 0 0 0\n", - " 0 0 0 0\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = V(t.zeros(3,4))\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 函数的使用与tensor一致\n", - "# 也可写成c = a + b\n", - "c = a.add(b)\n", - "c" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "d = c.sum()\n", - "d.backward() # 反向传播" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(12.0, Variable containing:\n", - " 12\n", - " [torch.FloatTensor of size 1])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 注意二者的区别\n", - "# 前者在取data后变为tensor,而后从tensor计算sum得到float\n", - "# 后者计算sum后仍然是Variable\n", - "c.data.sum(), c.sum()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(True, False, True)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 此处虽然没有指定c需要求导,但c依赖于a,而a需要求导,\n", - "# 因此c的requires_grad属性会自动设为True\n", - "a.requires_grad, b.requires_grad, c.requires_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(True, True, False)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 由用户创建的variable属于叶子节点,对应的grad_fn是None\n", - "a.is_leaf, b.is_leaf, c.is_leaf" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# c.grad是None, 因c不是叶子节点,它的梯度是用来计算a的梯度\n", - "# 所以虽然c.requires_grad = True,但其梯度计算完之后即被释放\n", - "c.grad is None" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "计算下面这个函数的导函数:\n", - "$$\n", - "y = x^2\\bullet e^x\n", - "$$\n", - "它的导函数是:\n", - "$$\n", - "{dy \\over dx} = 2x\\bullet e^x + x^2 \\bullet e^x\n", - "$$\n", - "来看看autograd的计算结果与手动求导计算结果的误差。" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "def f(x):\n", - " '''计算y'''\n", - " y = x**2 * t.exp(x)\n", - " return y\n", - "\n", - "def gradf(x):\n", - " '''手动求导函数'''\n", - " dx = 2*x*t.exp(x) + x**2*t.exp(x)\n", - " return dx" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 7.8454 0.4475 5.5884 0.1406\n", - " 0.4044 0.5008 0.4989 13.3268\n", - " 0.3547 0.0623 1.0497 4.2674\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.randn(3,4), requires_grad = True)\n", - "y = f(x)\n", - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 19.0962 2.1796 14.4631 1.0203\n", - " -0.3276 0.1172 -0.1745 29.7573\n", - " 1.8619 -0.3699 3.9812 11.6386\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.backward(t.ones(y.size())) # grad_variables形状与y一致\n", - "x.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 19.0962 2.1796 14.4631 1.0203\n", - " -0.3276 0.1172 -0.1745 29.7573\n", - " 1.8619 -0.3699 3.9812 11.6386\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# autograd的计算结果与利用公式手动计算的结果一致\n", - "gradf(x) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.2.2 计算图\n", - "\n", - "PyTorch中`autograd`的底层采用了计算图,计算图是一种特殊的有向无环图(DAG),用于记录算子与变量之间的关系。一般用矩形表示算子,椭圆形表示变量。如表达式$ \\textbf {z = wx + b}$可分解为$\\textbf{y = wx}$和$\\textbf{z = y + b}$,其计算图如图3-3所示,图中`MUL`,`ADD`都是算子,$\\textbf{w}$,$\\textbf{x}$,$\\textbf{b}$即变量。\n", - "\n", - "![图3-3:computation graph](imgs/com_graph.svg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "如上有向无环图中,$\\textbf{X}$和$\\textbf{b}$是叶子节点(leaf node),这些节点通常由用户自己创建,不依赖于其他变量。$\\textbf{z}$称为根节点,是计算图的最终目标。利用链式法则很容易求得各个叶子节点的梯度。\n", - "$${\\partial z \\over \\partial b} = 1,\\space {\\partial z \\over \\partial y} = 1\\\\\n", - "{\\partial y \\over \\partial w }= x,{\\partial y \\over \\partial x}= w\\\\\n", - "{\\partial z \\over \\partial x}= {\\partial z \\over \\partial y} {\\partial y \\over \\partial x}=1 * w\\\\\n", - "{\\partial z \\over \\partial w}= {\\partial z \\over \\partial y} {\\partial y \\over \\partial w}=1 * x\\\\\n", - "$$\n", - "而有了计算图,上述链式求导即可利用计算图的反向传播自动完成,其过程如图3-4所示。\n", - "\n", - "![图3-4:计算图的反向传播](imgs/com_graph_backward.svg)\n", - "\n", - "\n", - "在PyTorch实现中,autograd会随着用户的操作,记录生成当前variable的所有操作,并由此建立一个有向无环图。用户每进行一个操作,相应的计算图就会发生改变。更底层的实现中,图中记录了操作`Function`,每一个变量在图中的位置可通过其`grad_fn`属性在图中的位置推测得到。在反向传播过程中,autograd沿着这个图从当前变量(根节点$\\textbf{z}$)溯源,可以利用链式求导法则计算所有叶子节点的梯度。每一个前向传播操作的函数都有与之对应的反向传播函数用来计算输入的各个variable的梯度,这些函数的函数名通常以`Backward`结尾。下面结合代码学习autograd的实现细节。" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "x = V(t.ones(1))\n", - "b = V(t.rand(1), requires_grad = True)\n", - "w = V(t.rand(1), requires_grad = True)\n", - "y = w * x # 等价于y=w.mul(x)\n", - "z = y + b # 等价于z=y.add(b)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(False, True, True)" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x.requires_grad, b.requires_grad, w.requires_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 虽然未指定y.requires_grad为True,但由于y依赖于需要求导的w\n", - "# 故而y.requires_grad为True\n", - "y.requires_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(True, True, True)" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x.is_leaf, w.is_leaf, b.is_leaf" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(False, False)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.is_leaf, z.is_leaf" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# grad_fn可以查看这个variable的反向传播函数,\n", - "# z是add函数的输出,所以它的反向传播函数是AddBackward\n", - "z.grad_fn " - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((, 0),\n", - " (, 0))" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# next_functions保存grad_fn的输入,是一个tuple,tuple的元素也是Function\n", - "# 第一个是y,它是乘法(mul)的输出,所以对应的反向传播函数y.grad_fn是MulBackward\n", - "# 第二个是b,它是叶子节点,由用户创建,grad_fn为None,但是有\n", - "z.grad_fn.next_functions " - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# variable的grad_fn对应着和图中的function相对应\n", - "z.grad_fn.next_functions[0][0] == y.grad_fn" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((, 0), (None, 0))" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 第一个是w,叶子节点,需要求导,梯度是累加的\n", - "# 第二个是x,叶子节点,不需要求导,所以为None\n", - "y.grad_fn.next_functions" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(None, None)" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 叶子节点的grad_fn是None\n", - "w.grad_fn,x.grad_fn" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "计算w的梯度的时候,需要用到x的数值(${\\partial y\\over \\partial w} = x $),这些数值在前向过程中会保存成buffer,在计算完梯度之后会自动清空。为了能够多次反向传播需要指定`retain_graph`来保留这些buffer。" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1\n", - "[torch.FloatTensor of size 1]" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 使用retain_graph来保存buffer\n", - "z.backward(retain_graph=True)\n", - "w.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 2\n", - "[torch.FloatTensor of size 1]" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 多次反向传播,梯度累加,这也就是w中AccumulateGrad标识的含义\n", - "z.backward()\n", - "w.grad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "PyTorch使用的是动态图,它的计算图在每次前向传播时都是从头开始构建,所以它能够使用Python控制语句(如for、if等)根据需求创建计算图。这点在自然语言处理领域中很有用,它意味着你不需要事先构建所有可能用到的图的路径,图在运行时才构建。" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1\n", - "[torch.FloatTensor of size 1]" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def abs(x):\n", - " if x.data[0]>0: return x\n", - " else: return -x\n", - "x = V(t.ones(1),requires_grad=True)\n", - "y = abs(x)\n", - "y.backward()\n", - "x.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Variable containing:\n", - "-1\n", - "[torch.FloatTensor of size 1]\n", - "\n" - ] - } - ], - "source": [ - "x = V(-1*t.ones(1),requires_grad=True)\n", - "y = abs(x)\n", - "y.backward()\n", - "print(x.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 0\n", - " 0\n", - " 0\n", - " 6\n", - " 3\n", - " 2\n", - "[torch.FloatTensor of size 6]" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f(x):\n", - " result = 1\n", - " for ii in x:\n", - " if ii.data[0]>0: result=ii*result\n", - " return result\n", - "x = V(t.arange(-2,4),requires_grad=True)\n", - "y = f(x) # y = x[3]*x[4]*x[5]\n", - "y.backward()\n", - "x.grad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "变量的`requires_grad`属性默认为False,如果某一个节点requires_grad被设置为True,那么所有依赖它的节点`requires_grad`都是True。这其实很好理解,对于$ \\textbf{x}\\to \\textbf{y} \\to \\textbf{z}$,x.requires_grad = True,当需要计算$\\partial z \\over \\partial x$时,根据链式法则,$\\frac{\\partial z}{\\partial x} = \\frac{\\partial z}{\\partial y} \\frac{\\partial y}{\\partial x}$,自然也需要求$ \\frac{\\partial z}{\\partial y}$,所以y.requires_grad会被自动标为True. \n", - "\n", - "`volatile=True`是另外一个很重要的标识,它能够将所有依赖于它的节点全部都设为`volatile=True`,其优先级比`requires_grad=True`高。`volatile=True`的节点不会求导,即使`requires_grad=True`,也无法进行反向传播。对于不需要反向传播的情景(如inference,即测试推理时),该参数可实现一定程度的速度提升,并节省约一半显存,因其不需要分配空间计算梯度。" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(False, True, True)" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.ones(1))\n", - "w = V(t.rand(1), requires_grad=True)\n", - "y = x * w\n", - "# y依赖于w,而w.requires_grad = True\n", - "x.requires_grad, w.requires_grad, y.requires_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(False, True, False)" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.ones(1), volatile=True)\n", - "w = V(t.rand(1), requires_grad = True)\n", - "y = x * w\n", - "# y依赖于w和x,但x.volatile = True,w.requires_grad = True\n", - "x.requires_grad, w.requires_grad, y.requires_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(True, False, True)" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x.volatile, w.volatile, y.volatile" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:\n", - "- 使用autograd.grad函数\n", - "- 使用hook\n", - "\n", - "`autograd.grad`和`hook`方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用`hook`方法,但是在实际使用中应尽量避免修改grad的值。" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(True, True, True)" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.ones(3), requires_grad=True)\n", - "w = V(t.rand(3), requires_grad=True)\n", - "y = x * w\n", - "# y依赖于w,而w.requires_grad = True\n", - "z = y.sum()\n", - "x.requires_grad, w.requires_grad, y.requires_grad" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Variable containing:\n", - " 0.3776\n", - " 0.1184\n", - " 0.8554\n", - " [torch.FloatTensor of size 3], Variable containing:\n", - " 1\n", - " 1\n", - " 1\n", - " [torch.FloatTensor of size 3], None)" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 非叶子节点grad计算完之后自动清空,y.grad是None\n", - "z.backward()\n", - "(x.grad, w.grad, y.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Variable containing:\n", - " 1\n", - " 1\n", - " 1\n", - " [torch.FloatTensor of size 3],)" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 第一种方法:使用grad获取中间变量的梯度\n", - "x = V(t.ones(3), requires_grad=True)\n", - "w = V(t.rand(3), requires_grad=True)\n", - "y = x * w\n", - "z = y.sum()\n", - "# z对y的梯度,隐式调用backward()\n", - "t.autograd.grad(z, y)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y的梯度: \n", - " Variable containing:\n", - " 1\n", - " 1\n", - " 1\n", - "[torch.FloatTensor of size 3]\n", - "\n" - ] - } - ], - "source": [ - "# 第二种方法:使用hook\n", - "# hook是一个函数,输入是梯度,不应该有返回值\n", - "def variable_hook(grad):\n", - " print('y的梯度: \\r\\n',grad)\n", - "\n", - "x = V(t.ones(3), requires_grad=True)\n", - "w = V(t.rand(3), requires_grad=True)\n", - "y = x * w\n", - "# 注册hook\n", - "hook_handle = y.register_hook(variable_hook)\n", - "z = y.sum()\n", - "z.backward()\n", - "\n", - "# 除非你每次都要用hook,否则用完之后记得移除hook\n", - "hook_handle.remove()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "最后再来看看variable中grad属性和backward函数`grad_variables`参数的含义,这里直接下结论:\n", - "\n", - "- variable $\\textbf{x}$的梯度是目标函数${f(x)} $对$\\textbf{x}$的梯度,$\\frac{df(x)}{dx} = (\\frac {df(x)}{dx_0},\\frac {df(x)}{dx_1},...,\\frac {df(x)}{dx_N})$,形状和$\\textbf{x}$一致。\n", - "- 对于y.backward(grad_variables)中的grad_variables相当于链式求导法则中的$\\frac{\\partial z}{\\partial x} = \\frac{\\partial z}{\\partial y} \\frac{\\partial y}{\\partial x}$中的$\\frac{\\partial z}{\\partial y}$。z是目标函数,一般是一个标量,故而$\\frac{\\partial z}{\\partial y}$的形状与variable $\\textbf{y}$的形状一致。`z.backward()`在一定程度上等价于y.backward(grad_y)。`z.backward()`省略了grad_variables参数,是因为$z$是一个标量,而$\\frac{\\partial z}{\\partial z} = 1$" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 2\n", - " 4\n", - " 6\n", - "[torch.FloatTensor of size 3]" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.arange(0,3), requires_grad=True)\n", - "y = x**2 + x*2\n", - "z = y.sum()\n", - "z.backward() # 从z开始反向传播\n", - "x.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 2\n", - " 4\n", - " 6\n", - "[torch.FloatTensor of size 3]" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.arange(0,3), requires_grad=True)\n", - "y = x**2 + x*2\n", - "z = y.sum()\n", - "y_grad_variables = V(t.Tensor([1,1,1])) # dz/dy\n", - "y.backward(y_grad_variables) #从y开始反向传播\n", - "x.grad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "另外值得注意的是,只有对variable的操作才能使用autograd,如果对variable的data直接进行操作,将无法使用反向传播。除了对参数初始化,一般我们不会修改variable.data的值。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "在PyTorch中计算图的特点可总结如下:\n", - "\n", - "- autograd根据用户对variable的操作构建其计算图。对变量的操作抽象为`Function`。\n", - "- 对于那些不是任何函数(Function)的输出,由用户创建的节点称为叶子节点,叶子节点的`grad_fn`为None。叶子节点中需要求导的variable,具有`AccumulateGrad`标识,因其梯度是累加的。\n", - "- variable默认是不需要求导的,即`requires_grad`属性默认为False,如果某一个节点requires_grad被设置为True,那么所有依赖它的节点`requires_grad`都为True。\n", - "- variable的`volatile`属性默认为False,如果某一个variable的`volatile`属性被设为True,那么所有依赖它的节点`volatile`属性都为True。volatile属性为True的节点不会求导,volatile的优先级比`requires_grad`高。\n", - "- 多次反向传播时,梯度是累加的。反向传播的中间缓存会被清空,为进行多次反向传播需指定`retain_graph`=True来保存这些缓存。\n", - "- 非叶子节点的梯度计算完之后即被清空,可以使用`autograd.grad`或`hook`技术获取非叶子节点的值。\n", - "- variable的grad与data形状一致,应避免直接修改variable.data,因为对data的直接操作无法利用autograd进行反向传播\n", - "- 反向传播函数`backward`的参数`grad_variables`可以看成链式求导的中间结果,如果是标量,可以省略,默认为1\n", - "- PyTorch采用动态图设计,可以很方便地查看中间层的输出,动态的设计计算图结构。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.2.3 扩展autograd\n", - "\n", - "\n", - "目前绝大多数函数都可以使用`autograd`实现反向求导,但如果需要自己写一个复杂的函数,不支持自动反向求导怎么办? 写一个`Function`,实现它的前向传播和反向传播代码,`Function`对应于计算图中的矩形, 它接收参数,计算并返回结果。下面给出一个例子。\n", - "\n", - "```python\n", - "\n", - "class Mul(Function):\n", - " \n", - " @staticmethod\n", - " def forward(ctx, w, x, b, x_requires_grad = True):\n", - " ctx.x_requires_grad = x_requires_grad\n", - " ctx.save_for_backward(w,x)\n", - " output = w * x + b\n", - " return output\n", - " \n", - " @staticmethod\n", - " def backward(ctx, grad_output):\n", - " w,x = ctx.saved_variables\n", - " grad_w = grad_output * x\n", - " if ctx.x_requires_grad:\n", - " grad_x = grad_output * w\n", - " else:\n", - " grad_x = None\n", - " grad_b = grad_output * 1\n", - " return grad_w, grad_x, grad_b, None\n", - "```\n", - "\n", - "分析如下:\n", - "\n", - "- 自定义的Function需要继承autograd.Function,没有构造函数`__init__`,forward和backward函数都是静态方法\n", - "- forward函数的输入和输出都是Tensor,backward函数的输入和输出都是Variable\n", - "- backward函数的输出和forward函数的输入一一对应,backward函数的输入和forward函数的输出一一对应\n", - "- backward函数的grad_output参数即t.autograd.backward中的`grad_variables`\n", - "- 如果某一个输入不需要求导,直接返回None,如forward中的输入参数x_requires_grad显然无法对它求导,直接返回None即可\n", - "- 反向传播可能需要利用前向传播的某些中间结果,需要进行保存,否则前向传播结束后这些对象即被释放\n", - "\n", - "Function的使用利用Function.apply(variable)" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.autograd import Function\n", - "class MultiplyAdd(Function):\n", - " \n", - " @staticmethod\n", - " def forward(ctx, w, x, b): \n", - " print('type in forward',type(x))\n", - " ctx.save_for_backward(w,x)\n", - " output = w * x + b\n", - " return output\n", - " \n", - " @staticmethod\n", - " def backward(ctx, grad_output): \n", - " w,x = ctx.saved_variables\n", - " print('type in backward',type(x))\n", - " grad_w = grad_output * x\n", - " grad_x = grad_output * w\n", - " grad_b = grad_output * 1\n", - " return grad_w, grad_x, grad_b " - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "开始前向传播\n", - "type in backwardtype in forward \n", - "\n", - "开始反向传播\n" - ] - }, - { - "data": { - "text/plain": [ - "(None, Variable containing:\n", - " 1\n", - " [torch.FloatTensor of size 1], Variable containing:\n", - " 1\n", - " [torch.FloatTensor of size 1])" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.ones(1))\n", - "w = V(t.rand(1), requires_grad = True)\n", - "b = V(t.rand(1), requires_grad = True)\n", - "print('开始前向传播')\n", - "z=MultiplyAdd.apply(w, x, b)\n", - "print('开始反向传播')\n", - "z.backward()\n", - "\n", - "# x不需要求导,中间过程还是会计算它的导数,但随后被清空\n", - "x.grad, w.grad, b.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "开始前向传播\n", - "type in forward \n", - "开始反向传播\n", - "type in backward \n" - ] - }, - { - "data": { - "text/plain": [ - "(Variable containing:\n", - " 1\n", - " [torch.FloatTensor of size 1], Variable containing:\n", - " 0.9633\n", - " [torch.FloatTensor of size 1], Variable containing:\n", - " 1\n", - " [torch.FloatTensor of size 1])" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.ones(1))\n", - "w = V(t.rand(1), requires_grad = True)\n", - "b = V(t.rand(1), requires_grad = True)\n", - "print('开始前向传播')\n", - "z=MultiplyAdd.apply(w,x,b)\n", - "print('开始反向传播')\n", - "\n", - "# 调用MultiplyAdd.backward\n", - "# 输出grad_w, grad_x, grad_b\n", - "z.grad_fn.apply(V(t.ones(1)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "之所以forward函数的输入是tensor,而backward函数的输入是variable,是为了实现高阶求导。backward函数的输入输出虽然是variable,但在实际使用时autograd.Function会将输入variable提取为tensor,并将计算结果的tensor封装成variable返回。在backward函数中,之所以也要对variable进行操作,是为了能够计算梯度的梯度(backward of backward)。下面举例说明,有关torch.autograd.grad的更详细使用请参照文档。" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Variable containing:\n", - " 10\n", - " [torch.FloatTensor of size 1],)" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = V(t.Tensor([5]), requires_grad=True)\n", - "y = x ** 2\n", - "grad_x = t.autograd.grad(y, x, create_graph=True)\n", - "grad_x # dy/dx = 2 * x" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Variable containing:\n", - " 2\n", - " [torch.FloatTensor of size 1],)" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grad_grad_x = t.autograd.grad(grad_x[0],x)\n", - "grad_grad_x # 二阶导数 d(2x)/dx = 2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "这种设计虽然能让`autograd`具有高阶求导功能,但其也限制了Tensor的使用,因autograd中反向传播的函数只能利用当前已经有的Variable操作。这个设计是在`0.2`版本新加入的,为了更好的灵活性,也为了兼容旧版本的代码,PyTorch还提供了另外一种扩展autograd的方法。PyTorch提供了一个装饰器`@once_differentiable`,能够在backward函数中自动将输入的variable提取成tensor,把计算结果的tensor自动封装成variable。有了这个特性我们就能够很方便的使用numpy/scipy中的函数,操作不再局限于variable所支持的操作。但是这种做法正如名字中所暗示的那样只能求导一次,它打断了反向传播图,不再支持高阶求导。\n", - "\n", - "\n", - "上面所描述的都是新式Function,还有个legacy Function,可以带有`__init__`方法,`forward`和`backwad`函数也不需要声明为`@staticmethod`,但随着版本更迭,此类Function将越来越少遇到,在此不做更多介绍。\n", - "\n", - "此外在实现了自己的Function之后,还可以使用`gradcheck`函数来检测实现是否正确。`gradcheck`通过数值逼近来计算梯度,可能具有一定的误差,通过控制`eps`的大小可以控制容忍的误差。\n", - "关于这部份的内容可以参考github上开发者们的讨论[^3]。\n", - "\n", - "[^3]: https://github.com/pytorch/pytorch/pull/1016" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "下面举例说明如何利用Function实现sigmoid Function。" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [], - "source": [ - "class Sigmoid(Function):\n", - " \n", - " @staticmethod\n", - " def forward(ctx, x): \n", - " output = 1 / (1 + t.exp(-x))\n", - " ctx.save_for_backward(output)\n", - " return output\n", - " \n", - " @staticmethod\n", - " def backward(ctx, grad_output): \n", - " output, = ctx.saved_variables\n", - " grad_x = output * (1 - output) * grad_output\n", - " return grad_x " - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 采用数值逼近方式检验计算梯度的公式对不对\n", - "test_input = V(t.randn(3,4), requires_grad=True)\n", - "t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "232 µs ± 68.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", - "191 µs ± 6.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", - "215 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "def f_sigmoid(x):\n", - " y = Sigmoid.apply(x)\n", - " y.backward(t.ones(x.size()))\n", - " \n", - "def f_naive(x):\n", - " y = 1/(1 + t.exp(-x))\n", - " y.backward(t.ones(x.size()))\n", - " \n", - "def f_th(x):\n", - " y = t.sigmoid(x)\n", - " y.backward(t.ones(x.size()))\n", - " \n", - "x=V(t.randn(100, 100), requires_grad=True)\n", - "%timeit -n 100 f_sigmoid(x)\n", - "%timeit -n 100 f_naive(x)\n", - "%timeit -n 100 f_th(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "显然`f_sigmoid`要比单纯利用`autograd`加减和乘方操作实现的函数快不少,因为f_sigmoid的backward优化了反向传播的过程。另外可以看出系统实现的buildin接口(t.sigmoid)更快。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.2.4 小试牛刀: 用Variable实现线性回归\n", - "在上一节中讲解了利用tensor实现线性回归,在这一小节中,将讲解如何利用autograd/Variable实现线性回归,以此感受autograd的便捷之处。" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [], - "source": [ - "import torch as t\n", - "from torch.autograd import Variable as V\n", - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", - "from IPython import display" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [], - "source": [ - "# 设置随机数种子,为了在不同人电脑上运行时下面的输出一致\n", - "t.manual_seed(1000) \n", - "\n", - "def get_fake_data(batch_size=8):\n", - " ''' 产生随机数据:y = x*2 + 3,加上了一些噪声'''\n", - " x = t.rand(batch_size,1) * 20\n", - " y = x * 2 + (1 + t.randn(batch_size, 1))*3\n", - " return x, y" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD11JREFUeJzt3V+MXGd9xvHvU8eU5U+1gWxQvEAN\nKHKpSLHpKkobKaJA64AQMVFRSVtktbShEqhQkEVML4CLKkHmj6peRAokTS5oVArGQS3FWCFtWqmk\n3eAQO3XdFMqfrN14KSzQsqKO+fVix2Bv1t6Z9c7OzLvfj7SamXfP6DxaK0/mvOedc1JVSJJG308N\nOoAkaXVY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGXLSWO7vkkktq8+bNa7lL\nSRp5Dz744LeqamK57da00Ddv3sz09PRa7lKSRl6Sr3eznVMuktQIC12SGmGhS1Ijli30JE9N8s9J\nvpzkkSTv74y/IMkDSR5N8pdJntL/uJKkc+nmE/oPgVdU1UuBrcC1Sa4CPgB8pKouB74DvLl/MSVJ\ny1l2lUst3AHjfzovN3Z+CngF8Jud8buA9wG3rn5ESRpN+w7OsGf/UY7NzbNpfIxd27ewY9tk3/bX\n1Rx6kg1JHgJOAAeArwBzVfVEZ5PHgP6llKQRs+/gDLv3HmJmbp4CZubm2b33EPsOzvRtn10VelWd\nqqqtwHOBK4EXL7XZUu9NcmOS6STTs7OzK08qSSNkz/6jzJ88ddbY/MlT7Nl/tG/77GmVS1XNAX8H\nXAWMJzk9ZfNc4Ng53nNbVU1V1dTExLJfdJKkJhybm+9pfDV0s8plIsl45/kY8CrgCHAf8OudzXYC\n9/QrpCSNmk3jYz2Nr4ZuPqFfBtyX5GHgX4ADVfXXwLuBdyb5D+DZwO19SylJI2bX9i2Mbdxw1tjY\nxg3s2r6lb/vsZpXLw8C2Jca/ysJ8uiRpkdOrWdZylcuaXpxLktaTHdsm+1rgi/nVf0lqhIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtd\nkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWp\nERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqRHLFnqS5yW5L8mRJI8keXtn/H1J\nZpI81Pl5Tf/jSpLO5aIutnkCeFdVfSnJM4EHkxzo/O4jVfXB/sWTJHVr2UKvquPA8c7z7yc5Akz2\nO5gkqTc9zaEn2QxsAx7oDL0tycNJ7khy8SpnkyT1oOtCT/IM4FPAO6rqe8CtwIuArSx8gv/QOd53\nY5LpJNOzs7OrEFmStJSuCj3JRhbK/ONVtRegqh6vqlNV9SPgo8CVS723qm6rqqmqmpqYmFit3JKk\nRbpZ5RLgduBIVX34jPHLztjs9cDh1Y8nSepWN6tcrgbeBBxK8lBn7D3ADUm2AgV8DXhLXxJKkrrS\nzSqXfwSyxK8+u/pxJEkr5TdFJakRFrokNcJCl6RGdHNSVGrSvoMz7Nl/lGNz82waH2PX9i3s2OaX\noDW6LHStS/sOzrB77yHmT54CYGZunt17DwFY6hpZTrloXdqz/+iPy/y0+ZOn2LP/6IASSRfOQte6\ndGxuvqdxaRRY6FqXNo2P9TQujQILXevSru1bGNu44ayxsY0b2LV9y4ASSRfOk6Jal06f+HSVi1pi\noWvd2rFt0gJXU5xykaRGWOiS1AgLXZIaYaFLUiMsdElqhKtcJKlHw3phNwtdknowzBd2c8pFknow\nzBd2s9AlqQfDfGE3C12SejDMF3az0CWpB8N8YTdPikpSD4b5wm4WuiT1aFgv7OaUiyQ1wkKXpEZY\n6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIasWyhJ3lekvuSHEny\nSJK3d8afleRAkkc7jxf3P64k6Vy6+YT+BPCuqnoxcBXw1iQ/D9wE3FtVlwP3dl5rBO07OMPVt3yB\nF9z0N1x9yxfYd3Bm0JEkrcCyhV5Vx6vqS53n3weOAJPAdcBdnc3uAnb0K6T65/QNb2fm5il+csNb\nS10aPT3NoSfZDGwDHgCeU1XHYaH0gUtXO5z6b5hveCupN10XepJnAJ8C3lFV3+vhfTcmmU4yPTs7\nu5KM6qNhvuGtpN50VehJNrJQ5h+vqr2d4ceTXNb5/WXAiaXeW1W3VdVUVU1NTEysRmatomG+4a2k\n3nSzyiXA7cCRqvrwGb/6DLCz83wncM/qx1O/DfMNbyX1ppt7il4NvAk4lOShzth7gFuATyR5M/AN\n4A39iah+GuYb3krqTapqzXY2NTVV09PTa7Y/SWpBkgeramq57fymqCQ1wkKXpEZY6JLUCAtdkhph\noUtSI7pZtqhVsu/gjMsDJfWNhb5GTl8E6/R1U05fBAuw1CWtCgt9jZzvIlgW+uB41KSWWOhrxItg\nDR+PmtQaT4quES+CNXy8dLBaY6GvES+CNXw8alJrLPQ1smPbJDdffwWT42MEmBwf4+brr/DQfoA8\nalJrnENfQzu2TVrgQ2TX9i1nzaGDR00abRa61i0vHazWWOha1zxqUkucQ5ekRljoktQIC12SGmGh\nS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrok\nNWIkbnCx7+CMd5WRpGUMfaHvOzhz1n0fZ+bm2b33EIClLklnGPoplz37j551E1+A+ZOn2LP/6IAS\nSdJwGvpCPzY339O4JK1XQ1/om8bHehqXpPVq2UJPckeSE0kOnzH2viQzSR7q/LymXwF3bd/C2MYN\nZ42NbdzAru1b+rVLSRpJ3XxCvxO4donxj1TV1s7PZ1c31k/s2DbJzddfweT4GAEmx8e4+forPCEq\nSYssu8qlqu5Psrn/Uc5tx7ZJC1ySlnEhc+hvS/JwZ0rm4lVLJElakZUW+q3Ai4CtwHHgQ+faMMmN\nSaaTTM/Ozq5wd5Kk5ayo0Kvq8ao6VVU/Aj4KXHmebW+rqqmqmpqYmFhpTknSMlZU6EkuO+Pl64HD\n59pWkrQ2lj0pmuRu4OXAJUkeA94LvDzJVqCArwFv6WNGSVIXulnlcsMSw7f3IYsk6QIM/TdFJUnd\nsdAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgL\nXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAl\nqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGLFvoSe5IciLJ\n4TPGnpXkQJJHO48X9zemJGk53XxCvxO4dtHYTcC9VXU5cG/ntSRpgJYt9Kq6H/j2ouHrgLs6z+8C\ndqxyLklSj1Y6h/6cqjoO0Hm8dPUiSZJWou8nRZPcmGQ6yfTs7Gy/dydJ69ZKC/3xJJcBdB5PnGvD\nqrqtqqaqampiYmKFu5MkLWelhf4ZYGfn+U7gntWJI0laqW6WLd4N/BOwJcljSd4M3AL8apJHgV/t\nvJYkDdBFy21QVTec41evXOUskqQL4DdFJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUu\nSY2w0CWpERa6JDXCQpekRix7LZdRs+/gDHv2H+XY3DybxsfYtX0LO7ZNDjqWJPVdU4W+7+AMu/ce\nYv7kKQBm5ubZvfcQgKUuqXlNTbns2X/0x2V+2vzJU+zZf3RAiSRp7TRV6Mfm5nsal6SWNFXom8bH\nehqXpJY0Vei7tm9hbOOGs8bGNm5g1/YtA0okSWunqZOip098uspF0nrUVKHDQqlb4JLWo6amXCRp\nPbPQJakRFrokNcJCl6RGWOiS1IhU1drtLJkFvr7MZpcA31qDOBfCjKtnFHKacXWMQkYYzpw/W1UT\ny220poXejSTTVTU16BznY8bVMwo5zbg6RiEjjE7OpTjlIkmNsNAlqRHDWOi3DTpAF8y4ekYhpxlX\nxyhkhNHJ+SRDN4cuSVqZYfyELklagaEq9CRfS3IoyUNJpgedZylJxpN8Msm/JTmS5JcGnelMSbZ0\n/n6nf76X5B2DzrVYkj9K8kiSw0nuTvLUQWdaLMnbO/keGaa/YZI7kpxIcviMsWclOZDk0c7jxUOY\n8Q2dv+WPkgx8Fck5Mu7p/Lf9cJJPJxkfZMZeDVWhd/xKVW0d4mVDfwp8rqp+DngpcGTAec5SVUc7\nf7+twC8CPwA+PeBYZ0kyCfwhMFVVLwE2AG8cbKqzJXkJ8PvAlSz8O782yeWDTfVjdwLXLhq7Cbi3\nqi4H7u28HqQ7eXLGw8D1wP1rnmZpd/LkjAeAl1TVLwD/Duxe61AXYhgLfWgl+RngGuB2gKr6v6qa\nG2yq83ol8JWqWu7LXINwETCW5CLgacCxAedZ7MXAF6vqB1X1BPD3wOsHnAmAqrof+Pai4euAuzrP\n7wJ2rGmoRZbKWFVHqmpobvB7joyf7/x7A3wReO6aB7sAw1boBXw+yYNJbhx0mCW8EJgF/jzJwSQf\nS/L0QYc6jzcCdw86xGJVNQN8EPgGcBz4blV9frCpnuQwcE2SZyd5GvAa4HkDznQ+z6mq4wCdx0sH\nnKcFvwv87aBD9GLYCv3qqnoZ8GrgrUmuGXSgRS4CXgbcWlXbgP9l8Ie2S0ryFOB1wF8NOstinfnd\n64AXAJuApyf57cGmOltVHQE+wMIh+OeALwNPnPdNakaSP2bh3/vjg87Si6Eq9Ko61nk8wcK875WD\nTfQkjwGPVdUDndefZKHgh9GrgS9V1eODDrKEVwH/WVWzVXUS2Av88oAzPUlV3V5VL6uqa1g4NH90\n0JnO4/EklwF0Hk8MOM/ISrITeC3wWzVi67qHptCTPD3JM08/B36NhcPeoVFV/wV8M8npu06/EvjX\nAUY6nxsYwumWjm8AVyV5WpKw8HccqpPLAEku7Tw+n4WTecP69wT4DLCz83wncM8As4ysJNcC7wZe\nV1U/GHSeXg3NF4uSvJCfrMa4CPiLqvqTAUZaUpKtwMeApwBfBX6nqr4z2FRn68z5fhN4YVV9d9B5\nlpLk/cBvsHBYexD4var64WBTnS3JPwDPBk4C76yqewccCYAkdwMvZ+GqgI8D7wX2AZ8Ans/C/zDf\nUFWLT5wOOuO3gT8DJoA54KGq2j5kGXcDPw38d2ezL1bVHwwk4AoMTaFLki7M0Ey5SJIujIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1Ij/h/CJYJPfXoR0gAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# 来看看产生x-y分布是什么样的\n", - "x, y = get_fake_data()\n", - "plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4VOX5xvHvkz1AICA7EsMaQJag\nEQXUWrWCW6EuqP1VsWqpbW0LCArWVtwqrVbp4qXFaqVWawBRBFFqBUWtG0gWIOyCLGEnhCVAlvf3\nRwYLIZMMyayZ+3NduZg5c+bM09PjnZP3vPMcc84hIiINX0yoCxARkeBQ4IuIRAkFvohIlFDgi4hE\nCQW+iEiUUOCLiEQJBb6ISJRQ4IuIRAkFvohIlIgL5oe1bNnSpaenB/MjRSTK5G/ZV+PrqY3iaZ+a\nTKxZkCqqvyVLluxyzrWq73aCGvjp6eksXrw4mB8pIlFm8OQFbCkqOWm5GfzlprO4sm+7EFRVP2a2\n0R/b8XlIx8xizWypmc31PO9kZp+Z2RozyzazBH8UJCJSH+OHZJAcH3vCshiDSVefGZFh70+nMob/\nS6DguOe/A55yznUD9gK3+7MwEZG6GN6/AzcN6EiMZ8SmWXI8f7i+HyMHpYe0rnDg05COmZ0OXAk8\nCow1MwMuBr7vWWUaMAl4JgA1ioj4ZP/hUh6as4IZSzbTu0NTptyQSdfWKaEuK2z4OoY/BbgHOLbn\nTgOKnHNlnuebgQ5+rk1ExGdfbNjDmOwcthaVcNe3u/KLS7qREKeJiMerNfDN7Cpgh3NuiZlddGxx\nNatW21jfzEYBowDS0tLqWKaISPWOllXw1H9W8+wH6+jYvBEz7hzI2We0CHVZYcmXM/zBwHfN7Aog\nCWhK5Rl/qpnFec7yTwe2Vvdm59xUYCpAVlaW7rYiIn6zevt+Rr+aw4rCYm48pyP3X9WLJolBnXwY\nUWr9e8c5N9E5d7pzLh24EVjgnPs/YCFwnWe1kcDsgFUpInKcigrH8x99xVV//ojtxYd57pYsJl/b\nV2Ffi/rsnXuBV83sEWAp8Lx/ShIR8a5wXwnjZuTy8drdXNKjNZOv7UurlMRQlxURTinwnXPvA+97\nHq8HBvi/JBGR6r2Zu5X7X8+nrMLx2DV9uPGcjlgEfWM21PT3j4iEvX2HSvnNm8uYnbOV/mmpPDUi\nk/SWjUNdVsRR4ItIWPt47S7Gzchlx/4jjP1Od356URfiYjXdsi4U+CISlg6XlvP4/FU8/9FXdG7V\nmFk/GUS/jqmhLiuiKfBFJOws37qPMdk5rN5+gFsGnsHEy3uSnBBb+xulRgp8EQkb5RWOqYvW8+S7\nq0htlMCLPzyHizJah7qsBkOBLyJhYdOeQ9w9PZfPN+zh8t5tefR7fWjRWE14/UmBLyIh5Zxj5pLN\nPDhnBQB/uL4f15zVQdMtA0CBLyIhs+fgUe6blc87y7cxIL0FfxjRj44tGoW6rAZLgS8iIbFw1Q7u\nmZlH0aGjTLi8Bz+6oDOxMTqrDyQFvogE1aGjZTw2byUvfbqR7m2aMO2HA+jVvmmoy4oKCnwRCZqc\nTUWMzc5h/a6D3HF+J8YNySApXtMtg0WBLyIBV1ZewdML1/GnBWtok5LIK3ecy6CuLUNdVtRR4ItI\nQH216yBjsnPI2VTE8Mz2PDisN82S40NdVlRS4ItIQDjneOXzr3lkbgHxscafb+rP1f3ah7qsqKbA\nFxG/27H/MBNey2fByh2c37Ulj1/fl3bNkkNdVtRT4IuIX81fvo2Js/I5eKSMB67uxciB6cRoumVY\nUOCLiF8cOFLGQ3OWM33xZs5s35QpN2TSrU1KqMuS4yjwRaTeFm/Yw5jpOWzZW8JPL+rC6Eu7kxCn\nnvXhptbAN7MkYBGQ6Fl/pnPuATN7EfgWsM+z6q3OuZxAFSoi4edoWQV/fG81z7y/jg7Nk8n+8UDO\nSW8R6rLEC1/O8I8AFzvnDphZPPCRmb3teW28c25m4MoTkXC1dsd+RmfnsGxLMdeffTq/uboXKUma\nbhnOag1855wDDniexnt+XCCLEpHwVVHhmPbJBia/vZLGiXH89eazGXJm21CXJT7waZDNzGLNLAfY\nAbzrnPvM89KjZpZnZk+ZWaKX944ys8Vmtnjnzp1+KltEQmHbvsOM/PvnPDhnBYO7tuSd0Rco7COI\nVZ7A+7iyWSrwOvBzYDewDUgApgLrnHMP1fT+rKwst3jx4rpXKyIhMyd3K/e/sYyjZRXcf1VPvj8g\nrd49699YuoXH569ia1EJ7VOTGT8kg+H9O/ip4obDzJY457Lqu51TmqXjnCsys/eBoc65JzyLj5jZ\n34Fx9S1GRMLPvpJSHpi9jDdytpLZMZWnbsikU8vG9d7uG0u3MHFWPiWl5QBsKSph4qx8AIV+gNQ6\npGNmrTxn9phZMnApsNLM2nmWGTAcWBbIQkUk+P67bheXT1nEnLxCxlzanZl3DvRL2AM8Pn/VN2F/\nTElpOY/PX+WX7cvJfDnDbwdMM7NYKn9BTHfOzTWzBWbWCjAgB7gzgHWKSBAdLi3nD/9exd8++or0\n0xrz2k8Gkdkx1a+fsbWo5JSWS/35MksnD+hfzfKLA1KRiIRUQWExY7JzWLltPz84L437ruhJowT/\nf0ezfWoyW6oJ9/ap6rkTKPoqnIgAUF7hmLpoHcP+8jG7Dx7l7z88h0eG9wlI2AOMH5JBcpWbnyTH\nxzJ+SEZAPk/UWkFEgM17D3H39Fw++2oPQ85sw2PX9KVF44SAfuaxC7OapRM8CnyRKOacY9aXW5j0\n5nIc8MT1/bj2rA71nm7pq+H9Oyjgg0iBLxKl9h48yq/eyGde/jbOSW/OkyMy6diiUajLkgBS4ItE\noQ9W72T8jFz2HjrKvUN7MOrCzsSqZ32Dp8AXiSIlR8uZ/HYB0z7ZSLfWTXjh1nPo3aFZqMuSIFHg\ni0SJvM1FjMnOYd3Og9w2uBP3DM0gqcosGWnYFPgiDVxZeQXPvL+OP763hpZNEnn5jnMZ3LVlqMuS\nEFDgizRgG3cfZEx2Dl9+XcR3+7Xn4WG9adZIPeujlQJfpAFyzvHqF5t4eO4K4mKMP96YybBMTX+M\ndgp8kQZm5/4jTJyVx38KdjCoy2k8cX0/tSsQQIEv0qD8Z8V27n0tj/1Hyvj1Vb344aB0YjTdUjwU\n+CINwMEjZTw8dwWvfrGJnu2a8soNmWS0TQl1WRJmFPgiEW7Jxr2Myc5h095D3PmtLoz5TjcS4zTd\nUk6mwBeJUKXlFfzpvTU8vXAt7Zolkz1qIAM6tQh1WRLGFPgiEWjtjgOMyc4hf8s+rjv7dB64uhcp\nSZpuKTVT4ItEEOcc//hkI7+dV0CjhFie/cFZDO3dLtRlSYSoNfDNLAlYBCR61p/pnHvAzDoBrwIt\ngC+Bm51zRwNZrEg02158mPEz81i0eicXZbTi99f2pXXTpFCXJRHElzP8I8DFzrkDZhYPfGRmbwNj\ngaecc6+a2bPA7cAzAaxVJGrNyy/kvtfzOVxazsPDe/ODc9OC1rNeGg5f7mnrgAOep/GeHwdcDHzf\ns3waMAkFvohfFR8uZdLs5cxauoV+pzfjyRsy6dKqSajLkgjl0xi+mcUCS4CuwNPAOqDIOVfmWWUz\noO9ti/jRp+t3c/f0XLYVH+aXl3Tjrou7Eh+r21BL3fkU+M65ciDTzFKB14Ge1a1W3XvNbBQwCiAt\nLa2OZYpEjyNl5Tz579VM/XA9Z7RoxMw7B9I/rXmoy5IG4JRm6TjniszsfeA8INXM4jxn+acDW728\nZyowFSArK6vaXwoiUmnltmJGv5rDym37+f65adx/ZU8aJWgynfhHrX8fmlkrz5k9ZpYMXAoUAAuB\n6zyrjQRmB6pIkYauosLxtw/Xc+WfPmL19v0AfLBqJ/9evj3ElUlD4supQztgmmccPwaY7pyba2Yr\ngFfN7BFgKfB8AOsUabC2FJUwbnoun6zfTYxBhfvf8omz8gEY3l+XyKT+fJmlkwf0r2b5emBAIIoS\niQbOOWbnbOXXs5dRUeFITY6nqKT0hHVKSst5fP4qBb74hS75i4RA0aGj3PWvpYzOziGjTQpv//JC\n9lUJ+2O2FpUEuTppqHQ1SCTIPlyzk3Ezctl94Cjjh2Rw57e6EBtjtE9NZks14a6bl4i/6AxfJEgO\nl5Yz6c3l3Pz856QkxfPGzwbzs293JdZzg5LxQzJIjj+xrXFyfCzjh2SEolxpgHSGLxIEy7bsY3R2\nDmt3HODWQelMuLwHSVXC/dg4/ePzV7G1qIT2qcmMH5Kh8XvxGwW+SACVVzie/WAdT727mtOaJPDS\n7QO4oFsrr+sP799BAS8Bo8AXCZCvdx9izPQclmzcy5V92/Ho8N6kNkoIdVkSxRT4In7mnGPG4s08\nOGc5MTHGlBsyGZbZXt0tJeQU+CJ+tPvAESbOyuffK7YzsPNpPDGiHx00y0bChAJfxE8WrNzOPTPz\nKC4p4/4re3Lb4E7ExOisXsKHAl+kng4eKePReQW88tnX9Gibwj/vOJcebZuGuiyRkyjwRerhy6/3\nMjY7h417DvHjCzsz9rLuJMbF1v5GkRBQ4IvUQWl5BX9esJanF66lbdMk/vWj8ziv82mhLkukRgp8\nkVO0bucBxmbnkLt5H9f078CkYWfSNCk+1GWJ1EqBL+Ij5xz//OxrHn1rBUnxsTz9/bO4sm+7UJcl\n4jMFvogPdhQf5p7X8nh/1U4u7N6Kx6/rS5umSaEuS+SUKPBFvHhj6RYen7+KLUUlxBjExhgPDTuT\nm887Q1+ikoikwBepxhtLtzDhtTwOl1UAlXehijejaVK8wl4iltoji1TjkbdWfBP2xxwpq+Dx+atC\nVJFI/flyE/OOZrbQzArMbLmZ/dKzfJKZbTGzHM/PFYEvVySwjpSVM/ntlew6cLTa13X3KYlkvgzp\nlAF3O+e+NLMUYImZvet57Snn3BOBK08keFZv388vX82hoLCYRgmxHDpaftI6uvuURDJfbmJeCBR6\nHu83swJADbulwaiocPz9vxv43TsrSUmM47lbsjh4pIyJs/IpKf1f6Ef73aeOXcTWzVki1yldtDWz\ndKA/8BkwGLjLzG4BFlP5V8Deat4zChgFkJaWVs9yRfyrcF8J42bk8vHa3VzaszWTr+1LyyaJ37yu\ngKv0xtItJ/wC3FJUwsRZ+QBRu08ikTnnfFvRrAnwAfCoc26WmbUBdgEOeBho55y7raZtZGVlucWL\nF9ezZBH/mJ2zhV+/sYzDZRU0io+lqKSUDlEe7N4Mnryg2husd0hN5uMJF4egouhiZkucc1n13Y5P\nZ/hmFg+8BrzsnJsF4JzbftzrzwFz61uMSDDsO1TKr2cv483craSf1ojCfYcpKikFdObqjbeL1bqI\nHVl8maVjwPNAgXPuyeOWH/+d8u8By/xfnoh/fbx2F0OmLGJefiHjLuvO0bIKjlSZfllSWq7pl1V4\nu1iti9iRxZd5+IOBm4GLq0zB/L2Z5ZtZHvBtYEwgCxWpj8Ol5Tw0ZwX/97fPaJQYy6yfDuKui7tR\nuO9wtevrzPVE44dkkBx/YtvnaL+IHYl8maXzEVDdVwvn+b8cEf9bvnUfo1/NYc2OA4wceAYTLu9J\nckJleLVPTa52bFpnric6Nryli9iRTa0VpMEqr3BMXbSeJ99dRfNGCUy7bQDf6t7qhHXGD8nQ9Esf\nDe/fQQEf4RT40iBt2nOIsdNz+GLDXq7o05ZHh/eheeOEk9bTmatEEwW+NCjOOWYu2cyDc1ZgwJMj\n+vG9/h1qbHimM1eJFgp8aTD2HDzKxFl5zF++nQGdWvDkiH6c3rxRqMsSCRsKfGkQFq7cwfiZeRSX\nlHLfFT24/fzOxMaojbHI8RT4EtEOHS3jt/MK+OenX5PRJoWXbh9Az3ZNQ12WSFhS4EvEytlUxNjs\nHL7afZA7zu/EuCEZJFWZKy4i/6PAl4hTVl7BXxau5c8L1tImJZGX7ziXQV1ahroskbCnwJeI8tWu\ng4zOziF3UxHDM9vz4LDeNEuOD3VZIhFBgS8RwTnHK59/zSNzC0iIi+HPN/Xn6n7tQ12WSERR4EvY\n27H/MPfOzGPhqp1c0K0lj1/Xj7bNkkJdlkjEUeBLWJu/fBsTZ+Vz8EgZk67uxS0D04mpx3RL3bVJ\nopkCX8LSgSNlPPjmcmYs2cyZ7Zsy5YZMurVJqdc2ddcmiXYKfAk7X2zYw9jpOWzZW8Jd3+7KLy7p\nRkKcL528a/b4/FUnNEmD//W+V+BLNFDgS9g4WlbBlP+s5tkP1nF680ZM//FAstJb+G37umuTRDsF\nvoSF1dv3M/rVHFYUFnNDVkd+fXUvmiTW7fD0Nk6v3vcS7RT4ElIVFY4X/7uBye+spEliHH+9+WyG\nnNm2zturaZxeve8l2tUa+GbWEfgH0BaoAKY65/5oZi2AbCAd2ACMcM7tDVyp0tAU7ith/Iw8Plq7\ni0t6tGbytX1plZJYr23WNE7/8YSLv1lHs3QkGvlyhl8G3O2c+9LMUoAlZvYucCvwnnNusplNACYA\n9wauVGlI5uRu5Vev51Na7vjt9/pw04CONfas91Vt4/TqfS/RzJd72hYChZ7H+82sAOgADAMu8qw2\nDXgfBb7UYt+hUn7z5jJm52wls2MqU27IJL1lY79tX+P0It6d0lw3M0sH+gOfAW08vwyO/VJo7e/i\npGH579pdDP3jIubmFTL2O92ZeedAv4Y9VI7TJ1fpmKlxepFKPl+0NbMmwGvAaOdcsa9/fpvZKGAU\nQFpaWl1qlDB0Kt9YPewZQ3/+o6/o3LIxs34yiH4dUwNSl+5RK+KdOedqX8ksHpgLzHfOPelZtgq4\nyDlXaGbtgPedczWeRmVlZbnFixf7oWwJpaozYaDyLPqxa/qcFKwrthYzOnspq7cf4JaBZzDx8p4k\nJ6hnvcipMLMlzrms+m6n1iEdqzyVfx4oOBb2Hm8CIz2PRwKz61uMRIaaZsIcU17hePaDdQx7+iP2\nHirlxR+ew0PDeivsRULIlyGdwcDNQL6Z5XiW3QdMBqab2e3A18D1gSlRwk1tM2E27TnE3dNz+XzD\nHoae2ZbfXtOHFo0TglmiiFTDl1k6HwHeBuwv8W85Egm8zYRp1yyJmUs2M+nN5QA8cX0/rj2rg1+m\nW4pI/dW/I5VEnepmwiTFxdAqJZFxM3Lp1a4pb//yAq47+3SFvUgYUWsFOWVVZ8K0aJxAaXkFKwqL\nmXB5D350QWdi69GzXkQCQ4EvdTK8fweGnNmW384r4KVPN9K9TROm3NCfXu2bhro0EfFCgS91krup\niDHZOazfdZA7zu/EuCEZJMVrBo5IOFPgyykpK6/g6YXr+NOCNbROSeSVO85lUNeWoS5LRHygwBef\nfbXrIGOyc8jZVMSwzPY89N3eNGsUH+qyRMRHCnyplXOOf32+iYfnriA+1vjTTf35br/2oS5LRE6R\nAl9qtHP/ESa8lsd7K3cwuOtpPHF9P9o1U+dJkUikwBev3l2xnQmv5bH/SBm/uaoXtw5KJ0bTLUUi\nlgJfTnLgSBkPz1lB9uJN9GrXlFdvzKRbm5RQlyUi9aTAr4dTaREcKZZs3MOY7Fw27T3ETy/qwuhL\nu5MQpy9kizQECvw6qulm2ZEY+kfLKvjje6t55v11tE9NZvqPB3JOeotQlyUifqTAr6OaWgRHWuCv\n3bGf0dk5LNtSzPVnn85vru5FSpKmW4o0NAr8OqqtRXAkqKhw/OOTDTz29koaJcTy7A/OYmjvdqEu\nS0QCRIFfR5F+s+ztxYcZNyOXD9fs4qKMVvz+ur60TkkKdVkiEkC6GldHkXyz7LfyCrnsqUV8sWEP\njwzvzd9vPUdhLxIFdIZfR5F4s+ziw6U8MHs5ry/dQr+OqTw1oh+dWzUJdVkiEiQK/HoY3r9DWAf8\n8T5dv5u7p+eyrfgwoy/txs++3ZX42FP7A68hTkMViSa+3MT8BTPbYWbLjls2ycy2mFmO5+eKwJYp\ndXWkrJzfzivgpuc+JSEuhpl3DmT0pd3rFPYTZ+WzpagEx/+mob6xdEtgChcRv/Plv/oXgaHVLH/K\nOZfp+Znn37LEHwoKixn2l4+Zumg93x+Qxlu/OJ/+ac3rtK2apqGKSGTw5Sbmi8wsPfCliL9UVDj+\n9tF6npi/mqbJ8bxwaxYX92hTr202hGmoItGuPmP4d5nZLcBi4G7n3N7qVjKzUcAogLS0tHp8nPhi\n895DjJuRy6fr93BZrzY8dk0fTmuSWO/tRvo0VBGp+7TMZ4AuQCZQCPzB24rOuanOuSznXFarVq3q\n+HFSG+ccry/dzOVTPiR/8z5+f11f/nrz2X4Je4jsaagiUqlOZ/jOue3HHpvZc8Bcv1Ukp6zo0FF+\n9foy3sovJOuM5jx1QyYdWzTy62dE4jRUETlRnQLfzNo55wo9T78HLKtpfQmcD9fsZNyMXPYcPMo9\nQzP48YVdiA1Qz/pImoYqIierNfDN7F/ARUBLM9sMPABcZGaZgAM2AD8OYI1SjcOl5Ux+eyUv/ncD\n3Vo34fmR59C7Q7NQlyUiYcyXWTo3VbP4+QDUIj7K37yPMdNzWLvjALcN7sQ9QzNIqjK+LiJSlb5p\nG0HKyit49oN1TPnPGlo2SeSft5/L+d1ahrosEYkQCvwIsXH3QcZOz2XJxr1c1bcdjwzvTWqjhFCX\nJSIRRIEf5pxzTF+8iYfmrCAmxvjjjZkMy9SFUxE5dQr8MLbrwBEmzsrn3RXbGdj5NP4wop++6CQi\ndabAD1P/WbGdCbPyKD5cxv1X9uS2wZ2ICdB0SxGJDgr8MHPwSBmPvLWCf32+iZ7tmvLyHZlktE0J\ndVki0gAo8MPEG0u38OhbBew8cASAi3u05pkfnEViXO3TLdWnXkR8oVschoHXlmxm3Izcb8Ie4JN1\nu3k7f1ut71WfehHxlQI/xNbtPMCEWXmUVbgTlvvaa1596kXEVxrSCRHnHC99upGH566gtNxVu44v\nvebVp15EfKXAD4EdxYcZPzOPD1bvpKaJN75MwVSfehHxlYZ0guzt/EIum7KIz77aTbPkeCqqP7n3\nude8+tSLiK8U+EFSfLiUsdNz+MnLX5LWohFzf34BxSWlXtd/7Jo+Ps20Gd6/A49d04cOqckY0CE1\n2ef3ikh00ZBOEHy2fjdjp+dSuK+EX1zclZ9f0o342BivwzEdUpNPKbDVp15EfKEz/AA6UlbOY28X\ncONznxIXa8z8ySDGXpZBfGzlbtdwjIgEk87wA2TVtv2Mzs6hoLCYmwakcf+VPWmceOLu1m0DRSSY\nFPh+VlHheOHjr/j9O6tomhzH327J4tJebbyur+EYEQkWX25x+AJwFbDDOdfbs6wFkA2kU3mLwxHO\nub2BKzMybC0q4e7puXyyfjeX9mzD5Gv70LJJYqjLEhEBfBvDfxEYWmXZBOA951w34D3P86g2O2cL\nQ6YsIndzEb+7tg/P3XK2wl5Ewoov97RdZGbpVRYPo/LG5gDTgPeBe/1YV8TYd6iU+2cvY07uVs4+\nozlPjujHGac1DnVZIiInqesYfhvnXCGAc67QzFr7saawUlMnyo/W7GLcjFx2HTjCuMu6c+e3uhAX\nq4lPIhKeAn7R1sxGAaMA0tLS/L79QLYGPtaJ8lhzsmOdKI+WVVCwrZi/f7yBLq0a89wtg+lzejO/\nfKaISKDUNfC3m1k7z9l9O2CHtxWdc1OBqQBZWVleGgnUjbdABvwS+t46Ud73ej5lFY5bB6Vz79Ae\nJCfU3rNeRCTU6jr+8CYw0vN4JDDbP+WcmkC3BvbWcbKswvGP2wYw6btnKuxFJGLUGvhm9i/gEyDD\nzDab2e3AZOA7ZrYG+I7nedAFujWwt46T7ZomcWH3Vn75DBGRYPFlls5NXl66xM+1nLJAtwYed1l3\n7nkt74R+9UlxMdx7eQ+/bF9EJJgiekpJIHvR7D5whHeWb6O03JEQV7mbOqQmM/navvpmrIhEpIhu\nrRCoXjQLV+5g/Mw8iktKue+KHtxxfmdiarpTiYhIBIjowAf/9qI5dLSMR98q4OXPvqZH2xReun0A\nPds19cu2RURCLeID31+Wfr2XsdNz2bD7IKMu7MzY73QnKV4zcESk4Yj6wC8tr+AvC9byl4Vrads0\niVfuOI+BXU4LdVkiIn4X1YG/fucBxkzPJXdTEdf078CkYWfSNCk+1GWJiARERM/SqSvnHP/8dCND\np3xI3uYiAD77ag8LCrx+YVhEJOJF3Rn+jv2HuXdmHgtX7STGwHmm2Pu7LYOISLiJqjP8d5ZtY+iU\nD/nvut00S46nokpnH3+2ZRARCTdREfj7D5cyfkYud/5zCe1Tk3jrF+dTXFJa7br+assgIhJuGvyQ\nzhcb9jAmO4etRSXc9e2u/OKSbiTExQS8LYOISLhpsGf4R8sq+N07Kxnx10+IMWPGnQMZNyTjmzYJ\ngWzLICISjhrkGf7q7fsZ/WoOKwqLufGcjtx/VS+aJJ74PzVQbRlERMJVxAV+TXe4qqhw/P2/G/jd\nOytJSYzjuVuy+E6vNl635c+2DCIi4S6iAr+mO1yd27kF42bk8vHa3VzSozWTr+1Lq5TEUJYrIhJW\nIirwvd3h6sE5yymvcJRVOB67pg83ntMRM3W3FBE5XkQFvrcpk3sPldI/LZWnRmSS3rJxkKsSEYkM\nERX43qZSpiTFMePHA4mLbbCTjkRE6q1eCWlmG8ws38xyzGyxv4ryZvyQDJLiTiw5MTaGh4f1VtiL\niNTCH2f433bO7fLDdmrVrU0TUhslsK34MADtmiVx79AemmkjIuKDiBjSKa9wTF20niffXUXzRglM\nu20A3+reKtRliYhElPoGvgP+bWYO+KtzbmrVFcxsFDAKIC0t7ZQ/YNOeQ9w9PZfPN+zhij5teXR4\nH5o3Tqhn2SIi0ae+gT/YObfVzFoD75rZSufcouNX8PwSmAqQlZXlqttIdZxzzFyymQfnrMCAJ0f0\n43v9O2i6pYhIHdUr8J1zWz3/7jCz14EBwKKa31W7PQePct+sfN5Zvo0BnVrw5Ih+nN68UX03KyIS\n1eoc+GbWGIhxzu33PL4MeKi+BS1ctYN7ZuZRdOgoEy/vwR0XdCY2Rmf1IiL1VZ8z/DbA654hljjg\nFefcO3XdWMnRcn47r4CXPt1rJDRZAAAHNElEQVRIRpsUpv1wAL3aN61HeSIicrw6B75zbj3Qzx9F\n5G4qYkx2Dl/tPsiPLujE3ZdlkFSldbGIiNRPSKdllpVX8PTCdfxpwRrapCTy8h3nMqhLy1CWJCLS\nYIUs8L/adZAx2TnkbCpieGZ7BnRqwfgZeepNLyISIEEPfOccr3z+NY/MLSAhLoY/39Sf8grnte2x\nQl9ExD+CGvhlFY47pi3mvZU7OL9rS564vh9tmyUxePKCatsePz5/lQJfRMRPghr4a7bvp2TtLh64\nuhcjB6YT45lu6a3tsbflIiJy6oIa+HGxMcz9+fl0a5NywnJvbY/bpyYHqzQRkQYvqD2Fu7ZqclLY\nQ2Xb4+Qq0zCT42MZPyQjWKWJiDR4QT3D99YG59g4vbebk4uISP2FTXvk4f07KOBFRAJIt4kSEYkS\nCnwRkSihwBcRiRIKfBGRKKHAFxGJEgp8EZEoocAXEYkSCnwRkShRr8A3s6FmtsrM1prZBH8VJSIi\n/lfnwDezWOBp4HKgF3CTmfXyV2EiIuJf9TnDHwCsdc6td84dBV4FhvmnLBER8bf6BH4HYNNxzzd7\nlomISBiqT/O06npfupNWMhsFjPI8PWJmy+rxmcHSEtgV6iJ8oDr9JxJqBNXpb5FSp196xdcn8DcD\nHY97fjqwtepKzrmpwFQAM1vsnMuqx2cGher0r0ioMxJqBNXpb5FUpz+2U58hnS+AbmbWycwSgBuB\nN/1RlIiI+F+dz/Cdc2VmdhcwH4gFXnDOLfdbZSIi4lf1ugGKc24eMO8U3jK1Pp8XRKrTvyKhzkio\nEVSnv0VVnebcSddZRUSkAVJrBRGRKBGQwK+t5YKZJZpZtuf1z8wsPRB11FJjRzNbaGYFZrbczH5Z\nzToXmdk+M8vx/Pwm2HV66thgZvmeGk66Wm+V/uTZn3lmdlaQ68s4bh/lmFmxmY2usk5I9qWZvWBm\nO46fDmxmLczsXTNb4/m3uZf3jvSss8bMRoagzsfNbKXn/9PXzSzVy3trPD6CUOckM9ty3P+3V3h5\nb9BasXipM/u4GjeYWY6X9wZlf3rLoIAen845v/5QeQF3HdAZSABygV5V1vkp8Kzn8Y1Atr/r8KHO\ndsBZnscpwOpq6rwImBvs2qqpdQPQsobXrwDepvK7EecBn4Ww1lhgG3BGOOxL4ELgLGDZcct+D0zw\nPJ4A/K6a97UA1nv+be553DzIdV4GxHke/666On05PoJQ5yRgnA/HRY25EOg6q7z+B+A3odyf3jIo\nkMdnIM7wfWm5MAyY5nk8E7jEzKr7IlfAOOcKnXNfeh7vBwqI3G8KDwP+4Sp9CqSaWbsQ1XIJsM45\ntzFEn38C59wiYE+Vxccff9OA4dW8dQjwrnNuj3NuL/AuMDSYdTrn/u2cK/M8/ZTK77qElJf96Yug\ntmKpqU5P1owA/hWoz/dFDRkUsOMzEIHvS8uFb9bxHND7gNMCUItPPENK/YHPqnl5oJnlmtnbZnZm\nUAv7Hwf828yWWOU3l6sKpzYXN+L9P6Rw2JcAbZxzhVD5Hx3Qupp1wmmfAtxG5V9x1ant+AiGuzxD\nTy94GYIIp/15AbDdObfGy+tB359VMihgx2cgAt+Xlgs+tWUIBjNrArwGjHbOFVd5+Usqhyb6AX8G\n3gh2fR6DnXNnUdmZ9GdmdmGV18Nif1rlF/C+C8yo5uVw2Ze+Cot9CmBmvwLKgJe9rFLb8RFozwBd\ngEygkMrhkqrCZn8CN1Hz2X1Q92ctGeT1bdUsq3V/BiLwfWm58M06ZhYHNKNufybWi5nFU7mjX3bO\nzar6unOu2Dl3wPN4HhBvZi2DXCbOua2ef3cAr1P55/HxfGpzEQSXA18657ZXfSFc9qXH9mNDXp5/\nd1SzTljsU8/FuKuA/3OewduqfDg+Aso5t905V+6cqwCe8/L54bI/44BrgGxv6wRzf3rJoIAdn4EI\nfF9aLrwJHLuqfB2wwNvBHCiecbzngQLn3JNe1ml77NqCmQ2gcn/tDl6VYGaNzSzl2GMqL+RVbUD3\nJnCLVToP2HfsT8Ig83rmFA778jjHH38jgdnVrDMfuMzMmnuGKC7zLAsaMxsK3At81zl3yMs6vhwf\nAVXletH3vHx+uLRiuRRY6ZzbXN2LwdyfNWRQ4I7PAF19voLKK87rgF95lj1E5YELkETln/1rgc+B\nzoG8Gu6lxvOp/BMoD8jx/FwB3Anc6VnnLmA5lTMKPgUGhaDOzp7Pz/XUcmx/Hl+nUXkzmnVAPpAV\ngjobURngzY5bFvJ9SeUvoEKglMqzotupvF70HrDG828Lz7pZwN+Oe+9tnmN0LfDDENS5lspx2mPH\n57GZbe2BeTUdH0Gu8yXPcZdHZVi1q1qn5/lJuRDMOj3LXzx2TB63bkj2Zw0ZFLDjU9+0FRGJEvqm\nrYhIlFDgi4hECQW+iEiUUOCLiEQJBb6ISJRQ4IuIRAkFvohIlFDgi4hEif8HC5wNJAVBQWYAAAAA\nSUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.0188677310943604 2.8898627758026123\n" - ] - } - ], - "source": [ - "# 随机初始化参数\n", - "w = V(t.rand(1,1), requires_grad=True)\n", - "b = V(t.zeros(1,1), requires_grad=True)\n", - "\n", - "lr =0.001 # 学习率\n", - "\n", - "for ii in range(8000):\n", - " x, y = get_fake_data()\n", - " x, y = V(x), V(y)\n", - " \n", - " # forward:计算loss\n", - " y_pred = x.mm(w) + b.expand_as(y)\n", - " loss = 0.5 * (y_pred - y) ** 2\n", - " loss = loss.sum()\n", - " \n", - " # backward:手动计算梯度\n", - " loss.backward()\n", - " \n", - " # 更新参数\n", - " w.data.sub_(lr * w.grad.data)\n", - " b.data.sub_(lr * b.grad.data)\n", - " \n", - " # 梯度清零\n", - " w.grad.data.zero_()\n", - " b.grad.data.zero_()\n", - " \n", - " if ii%1000 ==0:\n", - " # 画图\n", - " display.clear_output(wait=True)\n", - " x = t.arange(0, 20).view(-1, 1)\n", - " y = x.mm(w.data) + b.data.expand_as(x)\n", - " plt.plot(x.numpy(), y.numpy()) # predicted\n", - " \n", - " x2, y2 = get_fake_data(batch_size=20) \n", - " plt.scatter(x2.numpy(), y2.numpy()) # true data\n", - " \n", - " plt.xlim(0,20)\n", - " plt.ylim(0,41) \n", - " plt.show()\n", - " plt.pause(0.5)\n", - " \n", - "print(w.data.squeeze()[0], b.data.squeeze()[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "用autograd实现的线性回归最大的不同点就在于autograd不需要计算反向传播,可以自动计算微分。这点不单是在深度学习,在许多机器学习的问题中都很有用。另外需要注意的是在每次反向传播之前要记得先把梯度清零。\n", - "\n", - "本章主要介绍了PyTorch中两个基础底层的数据结构:Tensor和autograd中的Variable。Tensor是一个类似Numpy数组的高效多维数值运算数据结构,有着和Numpy相类似的接口,并提供简单易用的GPU加速。Variable是autograd封装了Tensor并提供自动求导技术的,具有和Tensor几乎一样的接口。`autograd`是PyTorch的自动微分引擎,采用动态计算图技术,能够快速高效的计算导数。" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/0_basic/ref_Tensor.ipynb b/6_pytorch/0_basic/ref_Tensor.ipynb deleted file mode 100644 index a593048..0000000 --- a/6_pytorch/0_basic/ref_Tensor.ipynb +++ /dev/null @@ -1,3043 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 第三章 PyTorch基础:Tensor和Autograd\n", - "\n", - "## 3.1 Tensor\n", - "\n", - "Tensor,又名张量,读者可能对这个名词似曾相识,因它不仅在PyTorch中出现过,它也是Theano、TensorFlow、\n", - "Torch和MxNet中重要的数据结构。关于张量的本质不乏深度的剖析,但从工程角度来讲,可简单地认为它就是一个数组,且支持高效的科学计算。它可以是一个数(标量)、一维数组(向量)、二维数组(矩阵)和更高维的数组(高阶数据)。Tensor和Numpy的ndarrays类似,但PyTorch的tensor支持GPU加速。\n", - "\n", - "本节将系统讲解tensor的使用,力求面面俱到,但不会涉及每个函数。对于更多函数及其用法,读者可通过在IPython/Notebook中使用函数名加`?`查看帮助文档,或查阅PyTorch官方文档[^1]。\n", - "\n", - "[^1]: http://docs.pytorch.org" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'0.4.1'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Let's begin\n", - "from __future__ import print_function\n", - "import torch as t\n", - "t.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1.1 基础操作\n", - "\n", - "学习过Numpy的读者会对本节内容感到非常熟悉,因tensor的接口有意设计成与Numpy类似,以方便用户使用。但不熟悉Numpy也没关系,本节内容并不要求先掌握Numpy。\n", - "\n", - "从接口的角度来讲,对tensor的操作可分为两类:\n", - "\n", - "1. `torch.function`,如`torch.save`等。\n", - "2. 另一类是`tensor.function`,如`tensor.view`等。\n", - "\n", - "为方便使用,对tensor的大部分操作同时支持这两类接口,在本书中不做具体区分,如`torch.sum (torch.sum(a, b))`与`tensor.sum (a.sum(b))`功能等价。\n", - "\n", - "而从存储的角度来讲,对tensor的操作又可分为两类:\n", - "\n", - "1. 不会修改自身的数据,如 `a.add(b)`, 加法的结果会返回一个新的tensor。\n", - "2. 会修改自身的数据,如 `a.add_(b)`, 加法的结果仍存储在a中,a被修改了。\n", - "\n", - "函数名以`_`结尾的都是inplace方式, 即会修改调用者自己的数据,在实际应用中需加以区分。\n", - "\n", - "#### 创建Tensor\n", - "\n", - "在PyTorch中新建tensor的方法有很多,具体如表3-1所示。\n", - "\n", - "表3-1: 常见新建tensor的方法\n", - "\n", - "|函数|功能|\n", - "|:---:|:---:|\n", - "|Tensor(\\*sizes)|基础构造函数|\n", - "|ones(\\*sizes)|全1Tensor|\n", - "|zeros(\\*sizes)|全0Tensor|\n", - "|eye(\\*sizes)|对角线为1,其他为0|\n", - "|arange(s,e,step|从s到e,步长为step|\n", - "|linspace(s,e,steps)|从s到e,均匀切分成steps份|\n", - "|rand/randn(\\*sizes)|均匀/标准分布|\n", - "|normal(mean,std)/uniform(from,to)|正态分布/均匀分布|\n", - "|randperm(m)|随机排列|\n", - "\n", - "其中使用`Tensor`函数新建tensor是最复杂多变的方式,它既可以接收一个list,并根据list的数据新建tensor,也能根据指定的形状新建tensor,还能传入其他的tensor,下面举几个例子。" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0.0000, 0.0000, 0.0000],\n", - " [0.0000, 0.0000, 0.0000]])" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 指定tensor的形状\n", - "a = t.Tensor(2, 3)\n", - "a # 数值取决于内存空间的状态" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 2., 3.],\n", - " [4., 5., 6.]])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 用list的数据创建tensor\n", - "b = t.Tensor([[1,2,3],[4,5,6]])\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.tolist() # 把tensor转为list" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`tensor.size()`返回`torch.Size`对象,它是tuple的子类,但其使用方式与tuple略有区别" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 3])" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b_size = b.size()\n", - "b_size" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "6" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.numel() # b中元素总个数,2*3,等价于b.nelement()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([[ 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, -0.0000, 0.0000]]), tensor([2., 3.]))" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 创建一个和b形状一样的tensor\n", - "c = t.Tensor(b_size)\n", - "# 创建一个元素为2和3的tensor\n", - "d = t.Tensor((2, 3))\n", - "c, d" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "除了`tensor.size()`,还可以利用`tensor.shape`直接查看tensor的形状,`tensor.shape`等价于`tensor.size()`" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 3])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "c.shape??" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "需要注意的是,`t.Tensor(*sizes)`创建tensor时,系统不会马上分配空间,只是会计算剩余的内存是否足够使用,使用到tensor时才会分配,而其它操作都是在创建完tensor之后马上进行空间分配。其它常用的创建tensor的方法举例如下。" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 1., 1.],\n", - " [1., 1., 1.]])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.ones(2, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0., 0., 0.],\n", - " [0., 0., 0.]])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.zeros(2, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([1, 3, 5])" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.arange(1, 6, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 1.0000, 5.5000, 10.0000])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.linspace(1, 10, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 0.4388, 0.9361, 0.8411],\n", - " [-1.0667, -0.5187, 0.5520]])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.randn(2, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([2, 0, 4, 1, 3])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.randperm(5) # 长度为5的随机排列" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 0., 0.],\n", - " [0., 1., 0.]])" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.eye(2, 3) # 对角线为1, 不要求行列数一致" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 常用Tensor操作" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "通过`tensor.view`方法可以调整tensor的形状,但必须保证调整前后元素总数一致。`view`不会修改自身的数据,返回的新tensor与源tensor共享内存,也即更改其中的一个,另外一个也会跟着改变。在实际应用中可能经常需要添加或减少某一维度,这时候`squeeze`和`unsqueeze`两个函数就派上用场了。" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0, 1, 2],\n", - " [3, 4, 5]])" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.arange(0, 6)\n", - "a.view(2, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0, 1, 2],\n", - " [3, 4, 5]])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = a.view(-1, 3) # 当某一维为-1的时候,会自动计算它的大小\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[0, 1, 2]],\n", - "\n", - " [[3, 4, 5]]])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.unsqueeze(1) # 注意形状,在第1维(下标从0开始)上增加“1”" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[0, 1, 2]],\n", - "\n", - " [[3, 4, 5]]])" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.unsqueeze(-2) # -2表示倒数第二个维度" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[[0, 1, 2],\n", - " [3, 4, 5]]]])" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c = b.view(1, 1, 1, 2, 3)\n", - "c.squeeze(0) # 压缩第0维的“1”" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0, 1, 2],\n", - " [3, 4, 5]])" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c.squeeze() # 把所有维度为“1”的压缩" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 0, 100, 2],\n", - " [ 3, 4, 5]])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[1] = 100\n", - "b # a修改,b作为view之后的,也会跟着修改" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`resize`是另一种可用来调整`size`的方法,但与`view`不同,它可以修改tensor的大小。如果新大小超过了原大小,会自动分配新的内存空间,而如果新大小小于原大小,则之前的数据依旧会被保存,看一个例子。" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 100 2\n", - "[torch.FloatTensor of size 1x3]" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.resize_(1, 3)\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.0000e+00 1.0000e+02 2.0000e+00\n", - " 3.0000e+00 4.0000e+00 5.0000e+00\n", - " 4.1417e+36 4.5731e-41 6.7262e-44\n", - "[torch.FloatTensor of size 3x3]" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.resize_(3, 3) # 旧的数据依旧保存着,多出的大小会分配新空间\n", - "b" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 索引操作\n", - "\n", - "Tensor支持与numpy.ndarray类似的索引操作,语法上也类似,下面通过一些例子,讲解常用的索引操作。如无特殊说明,索引出来的结果与原tensor共享内存,也即修改一个,另一个会跟着修改。" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.2355 0.8276 0.6279 -2.3826\n", - " 0.3533 1.3359 0.1627 1.7314\n", - " 0.8121 0.3059 2.4352 1.4577\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.randn(3, 4)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.2355\n", - " 0.8276\n", - " 0.6279\n", - "-2.3826\n", - "[torch.FloatTensor of size 4]" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[0] # 第0行(下标从0开始)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.2355\n", - " 0.3533\n", - " 0.8121\n", - "[torch.FloatTensor of size 3]" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[:, 0] # 第0列" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.6279084086418152" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[0][2] # 第0行第2个元素,等价于a[0, 2]" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "-2.3825833797454834" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[0, -1] # 第0行最后一个元素" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.2355 0.8276 0.6279 -2.3826\n", - " 0.3533 1.3359 0.1627 1.7314\n", - "[torch.FloatTensor of size 2x4]" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[:2] # 前两行" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.2355 0.8276\n", - " 0.3533 1.3359\n", - "[torch.FloatTensor of size 2x2]" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[:2, 0:2] # 前两行,第0,1列" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " 0.2355 0.8276\n", - "[torch.FloatTensor of size 1x2]\n", - "\n", - "\n", - " 0.2355\n", - " 0.8276\n", - "[torch.FloatTensor of size 2]\n", - "\n" - ] - } - ], - "source": [ - "print(a[0:1, :2]) # 第0行,前两列 \n", - "print(a[0, :2]) # 注意两者的区别:形状不同" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 0 0 0\n", - " 0 1 0 1\n", - " 0 0 1 1\n", - "[torch.ByteTensor of size 3x4]" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a > 1 # 返回一个ByteTensor" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1.3359\n", - " 1.7314\n", - " 2.4352\n", - " 1.4577\n", - "[torch.FloatTensor of size 4]" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[a>1] # 等价于a.masked_select(a>1)\n", - "# 选择结果与原tensor不共享内存空间" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.2355 0.8276 0.6279 -2.3826\n", - " 0.3533 1.3359 0.1627 1.7314\n", - "[torch.FloatTensor of size 2x4]" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[t.LongTensor([0,1])] # 第0行和第1行" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "其它常用的选择函数如表3-2所示。\n", - "\n", - "表3-2常用的选择函数\n", - "\n", - "函数|功能|\n", - ":---:|:---:|\n", - "index_select(input, dim, index)|在指定维度dim上选取,比如选取某些行、某些列\n", - "masked_select(input, mask)|例子如上,a[a>0],使用ByteTensor进行选取\n", - "non_zero(input)|非0元素的下标\n", - "gather(input, dim, index)|根据index,在dim维度上选取数据,输出的size与index一样\n", - "\n", - "\n", - "`gather`是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:\n", - "\n", - "```python\n", - "out[i][j] = input[index[i][j]][j] # dim=0\n", - "out[i][j] = input[i][index[i][j]] # dim=1\n", - "```\n", - "三维tensor的`gather`操作同理,下面举几个例子。" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 1 2 3\n", - " 4 5 6 7\n", - " 8 9 10 11\n", - " 12 13 14 15\n", - "[torch.FloatTensor of size 4x4]" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.arange(0, 16).view(4, 4)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 5 10 15\n", - "[torch.FloatTensor of size 1x4]" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 选取对角线的元素\n", - "index = t.LongTensor([[0,1,2,3]])\n", - "a.gather(0, index)" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 3\n", - " 6\n", - " 9\n", - " 12\n", - "[torch.FloatTensor of size 4x1]" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 选取反对角线上的元素\n", - "index = t.LongTensor([[3,2,1,0]]).t()\n", - "a.gather(1, index)" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 12 9 6 3\n", - "[torch.FloatTensor of size 1x4]" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 选取反对角线上的元素,注意与上面的不同\n", - "index = t.LongTensor([[3,2,1,0]])\n", - "a.gather(0, index)" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 3\n", - " 5 6\n", - " 10 9\n", - " 15 12\n", - "[torch.FloatTensor of size 4x2]" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 选取两个对角线上的元素\n", - "index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()\n", - "b = a.gather(1, index)\n", - "b" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "与`gather`相对应的逆操作是`scatter_`,`gather`把数据从input中按index取出,而`scatter_`是把取出的数据再放回去。注意`scatter_`函数是inplace操作。\n", - "\n", - "```python\n", - "out = input.gather(dim, index)\n", - "-->近似逆操作\n", - "out = Tensor()\n", - "out.scatter_(dim, index)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 0 0 3\n", - " 0 5 6 0\n", - " 0 9 10 0\n", - " 12 0 0 15\n", - "[torch.FloatTensor of size 4x4]" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 把两个对角线元素放回去到指定位置\n", - "c = t.zeros(4,4)\n", - "c.scatter_(1, index, b)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 高级索引\n", - "PyTorch在0.2版本中完善了索引操作,目前已经支持绝大多数numpy的高级索引[^10]。高级索引可以看成是普通索引操作的扩展,但是高级索引操作的结果一般不和原始的Tensor贡献内出。 \n", - "[^10]: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "(0 ,.,.) = \n", - " 0 1 2\n", - " 3 4 5\n", - " 6 7 8\n", - "\n", - "(1 ,.,.) = \n", - " 9 10 11\n", - " 12 13 14\n", - " 15 16 17\n", - "\n", - "(2 ,.,.) = \n", - " 18 19 20\n", - " 21 22 23\n", - " 24 25 26\n", - "[torch.FloatTensor of size 3x3x3]" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = t.arange(0,27).view(3,3,3)\n", - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 14\n", - " 24\n", - "[torch.FloatTensor of size 2]" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[[1, 2], [1, 2], [2, 0]] # x[1,1,2]和x[2,2,0]" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 19\n", - " 10\n", - " 1\n", - "[torch.FloatTensor of size 3]" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[[2, 1, 0], [0], [1]] # x[2,0,1],x[1,0,1],x[0,0,1]" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "(0 ,.,.) = \n", - " 0 1 2\n", - " 3 4 5\n", - " 6 7 8\n", - "\n", - "(1 ,.,.) = \n", - " 18 19 20\n", - " 21 22 23\n", - " 24 25 26\n", - "[torch.FloatTensor of size 2x3x3]" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[[0, 2], ...] # x[0] 和 x[2]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Tensor类型\n", - "\n", - "Tensor有不同的数据类型,如表3-3所示,每种类型分别对应有CPU和GPU版本(HalfTensor除外)。默认的tensor是FloatTensor,可通过`t.set_default_tensor_type` 来修改默认tensor类型(如果默认类型为GPU tensor,则所有操作都将在GPU上进行)。Tensor的类型对分析内存占用很有帮助。例如对于一个size为(1000, 1000, 1000)的FloatTensor,它有`1000*1000*1000=10^9`个元素,每个元素占32bit/8 = 4Byte内存,所以共占大约4GB内存/显存。HalfTensor是专门为GPU版本设计的,同样的元素个数,显存占用只有FloatTensor的一半,所以可以极大缓解GPU显存不足的问题,但由于HalfTensor所能表示的数值大小和精度有限[^2],所以可能出现溢出等问题。\n", - "\n", - "[^2]: https://stackoverflow.com/questions/872544/what-range-of-numbers-can-be-represented-in-a-16-32-and-64-bit-ieee-754-syste\n", - "\n", - "表3-3: tensor数据类型\n", - "\n", - "数据类型|\tCPU tensor\t|GPU tensor|\n", - ":---:|:---:|:--:|\n", - "32-bit 浮点|\ttorch.FloatTensor\t|torch.cuda.FloatTensor\n", - "64-bit 浮点|\ttorch.DoubleTensor|\ttorch.cuda.DoubleTensor\n", - "16-bit 半精度浮点|\tN/A\t|torch.cuda.HalfTensor\n", - "8-bit 无符号整形(0~255)|\ttorch.ByteTensor|\ttorch.cuda.ByteTensor\n", - "8-bit 有符号整形(-128~127)|\ttorch.CharTensor\t|torch.cuda.CharTensor\n", - "16-bit 有符号整形 |\ttorch.ShortTensor|\ttorch.cuda.ShortTensor\n", - "32-bit 有符号整形 \t|torch.IntTensor\t|torch.cuda.IntTensor\n", - "64-bit 有符号整形 \t|torch.LongTensor\t|torch.cuda.LongTensor\n", - "\n", - "各数据类型之间可以互相转换,`type(new_type)`是通用的做法,同时还有`float`、`long`、`half`等快捷方法。CPU tensor与GPU tensor之间的互相转换通过`tensor.cuda`和`tensor.cpu`方法实现。Tensor还有一个`new`方法,用法与`t.Tensor`一样,会调用该tensor对应类型的构造函数,生成与当前tensor类型一致的tensor。" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [], - "source": [ - "# 设置默认tensor,注意参数是字符串\n", - "t.set_default_tensor_type('torch.IntTensor')" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "-1.7683e+09 2.1918e+04 1.0000e+00\n", - " 0.0000e+00 1.0000e+00 0.0000e+00\n", - "[torch.IntTensor of size 2x3]" - ] - }, - "execution_count": 48, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.Tensor(2,3)\n", - "a # 现在a是IntTensor" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "-1.7683e+09 2.1918e+04 1.0000e+00\n", - " 0.0000e+00 1.0000e+00 0.0000e+00\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 把a转成FloatTensor,等价于b=a.type(t.FloatTensor)\n", - "b = a.float() \n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "-1.7683e+09 2.1918e+04 1.0000e+00\n", - " 0.0000e+00 1.0000e+00 0.0000e+00\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c = a.type_as(b)\n", - "c" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "-1.7682e+09 2.1918e+04 3.0000e+00\n", - " 0.0000e+00 1.0000e+00 0.0000e+00\n", - "[torch.IntTensor of size 2x3]" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d = a.new(2,3) # 等价于torch.IntTensor(2,3)\n", - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\u001b[0;31mSignature:\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mSource:\u001b[0m \n", - " \u001b[0;32mdef\u001b[0m \u001b[0mnew\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;34mr\"\"\"Constructs a new tensor of the same data type as :attr:`self` tensor.\u001b[0m\n", - "\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m Any valid argument combination to the tensor constructor is accepted by\u001b[0m\n", - "\u001b[0;34m this method, including sizes, :class:`torch.Storage`, NumPy ndarray,\u001b[0m\n", - "\u001b[0;34m Python Sequence, etc. See :ref:`torch.Tensor ` for more\u001b[0m\n", - "\u001b[0;34m details.\u001b[0m\n", - "\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m .. note:: For CUDA tensors, this method will create new tensor on the\u001b[0m\n", - "\u001b[0;34m same device as this tensor.\u001b[0m\n", - "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFile:\u001b[0m /usr/lib/python3.6/site-packages/torch/tensor.py\n", - "\u001b[0;31mType:\u001b[0m method\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# 查看函数new的源码\n", - "a.new??" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [], - "source": [ - "# 恢复之前的默认设置\n", - "t.set_default_tensor_type('torch.FloatTensor')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 逐元素操作\n", - "\n", - "这部分操作会对tensor的每一个元素(point-wise,又名element-wise)进行操作,此类操作的输入与输出形状一致。常用的操作如表3-4所示。\n", - "\n", - "表3-4: 常见的逐元素操作\n", - "\n", - "|函数|功能|\n", - "|:--:|:--:|\n", - "|abs/sqrt/div/exp/fmod/log/pow..|绝对值/平方根/除法/指数/求余/求幂..|\n", - "|cos/sin/asin/atan2/cosh..|相关三角函数|\n", - "|ceil/round/floor/trunc| 上取整/四舍五入/下取整/只保留整数部分|\n", - "|clamp(input, min, max)|超过min和max部分截断|\n", - "|sigmod/tanh..|激活函数\n", - "\n", - "对于很多操作,例如div、mul、pow、fmod等,PyTorch都实现了运算符重载,所以可以直接使用运算符。如`a ** 2` 等价于`torch.pow(a,2)`, `a * 2`等价于`torch.mul(a,2)`。\n", - "\n", - "其中`clamp(x, min, max)`的输出满足以下公式:\n", - "$$\n", - "y_i =\n", - "\\begin{cases}\n", - "min, & \\text{if } x_i \\lt min \\\\\n", - "x_i, & \\text{if } min \\le x_i \\le max \\\\\n", - "max, & \\text{if } x_i \\gt max\\\\\n", - "\\end{cases}\n", - "$$\n", - "`clamp`常用在某些需要比较大小的地方,如取一个tensor的每个元素与另一个数的较大值。" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1.0000 0.5403 -0.4161\n", - "-0.9900 -0.6536 0.2837\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.arange(0, 6).view(2, 3)\n", - "t.cos(a)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 1 2\n", - " 0 1 2\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a % 3 # 等价于t.fmod(a, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 1 4\n", - " 9 16 25\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 56, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a ** 2 # 等价于t.pow(a, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " 0 1 2\n", - " 3 4 5\n", - "[torch.FloatTensor of size 2x3]\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "\n", - " 3 3 3\n", - " 3 4 5\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 取a中的每一个元素与3相比较大的一个 (小于3的截断成3)\n", - "print(a)\n", - "t.clamp(a, min=3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 归并操作 \n", - "此类操作会使输出形状小于输入形状,并可以沿着某一维度进行指定操作。如加法`sum`,既可以计算整个tensor的和,也可以计算tensor中每一行或每一列的和。常用的归并操作如表3-5所示。\n", - "\n", - "表3-5: 常用归并操作\n", - "\n", - "|函数|功能|\n", - "|:---:|:---:|\n", - "|mean/sum/median/mode|均值/和/中位数/众数|\n", - "|norm/dist|范数/距离|\n", - "|std/var|标准差/方差|\n", - "|cumsum/cumprod|累加/累乘|\n", - "\n", - "以上大多数函数都有一个参数**`dim`**,用来指定这些操作是在哪个维度上执行的。关于dim(对应于Numpy中的axis)的解释众说纷纭,这里提供一个简单的记忆方式:\n", - "\n", - "假设输入的形状是(m, n, k)\n", - "\n", - "- 如果指定dim=0,输出的形状就是(1, n, k)或者(n, k)\n", - "- 如果指定dim=1,输出的形状就是(m, 1, k)或者(m, k)\n", - "- 如果指定dim=2,输出的形状就是(m, n, 1)或者(m, n)\n", - "\n", - "size中是否有\"1\",取决于参数`keepdim`,`keepdim=True`会保留维度`1`。注意,以上只是经验总结,并非所有函数都符合这种形状变化方式,如`cumsum`。" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 2 2 2\n", - "[torch.FloatTensor of size 1x3]" - ] - }, - "execution_count": 58, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = t.ones(2, 3)\n", - "b.sum(dim = 0, keepdim=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 2\n", - " 2\n", - " 2\n", - "[torch.FloatTensor of size 3]" - ] - }, - "execution_count": 59, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# keepdim=False,不保留维度\"1\",注意形状\n", - "b.sum(dim=0, keepdim=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 3\n", - " 3\n", - "[torch.FloatTensor of size 2]" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.sum(dim=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " 0 1 2\n", - " 3 4 5\n", - "[torch.FloatTensor of size 2x3]\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "\n", - " 0 1 3\n", - " 3 7 12\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.arange(0, 6).view(2, 3)\n", - "print(a)\n", - "a.cumsum(dim=1) # 沿着行累加" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 比较\n", - "比较函数中有一些是逐元素比较,操作类似于逐元素操作,还有一些则类似于归并操作。常用比较函数如表3-6所示。\n", - "\n", - "表3-6: 常用比较函数\n", - "\n", - "|函数|功能|\n", - "|:--:|:--:|\n", - "|gt/lt/ge/le/eq/ne|大于/小于/大于等于/小于等于/等于/不等|\n", - "|topk|最大的k个数|\n", - "|sort|排序|\n", - "|max/min|比较两个tensor最大最小值|\n", - "\n", - "表中第一行的比较操作已经实现了运算符重载,因此可以使用`a>=b`、`a>b`、`a!=b`、`a==b`,其返回结果是一个`ByteTensor`,可用来选取元素。max/min这两个操作比较特殊,以max来说,它有以下三种使用情况:\n", - "- t.max(tensor):返回tensor中最大的一个数\n", - "- t.max(tensor,dim):指定维上最大的数,返回tensor和下标\n", - "- t.max(tensor1, tensor2): 比较两个tensor相比较大的元素\n", - "\n", - "至于比较一个tensor和一个数,可以使用clamp函数。下面举例说明。" - ] - }, - { - "cell_type": "code", - "execution_count": 62, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 3 6\n", - " 9 12 15\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 62, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.linspace(0, 15, 6).view(2, 3)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 63, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 15 12 9\n", - " 6 3 0\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 63, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = t.linspace(15, 0, 6).view(2, 3)\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 0 0\n", - " 1 1 1\n", - "[torch.ByteTensor of size 2x3]" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a>b" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 9\n", - " 12\n", - " 15\n", - "[torch.FloatTensor of size 3]" - ] - }, - "execution_count": 65, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[a>b] # a中大于b的元素" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "15.0" - ] - }, - "execution_count": 66, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.max(a)" - ] - }, - { - "cell_type": "code", - "execution_count": 67, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(\n", - " 15\n", - " 6\n", - " [torch.FloatTensor of size 2], \n", - " 0\n", - " 0\n", - " [torch.LongTensor of size 2])" - ] - }, - "execution_count": 67, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.max(b, dim=1) \n", - "# 第一个返回值的15和6分别表示第0行和第1行最大的元素\n", - "# 第二个返回值的0和0表示上述最大的数是该行第0个元素" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 15 12 9\n", - " 9 12 15\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 68, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.max(a,b)" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 10 10 10\n", - " 10 12 15\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 69, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 比较a和10较大的元素\n", - "t.clamp(a, min=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 线性代数\n", - "\n", - "PyTorch的线性函数主要封装了Blas和Lapack,其用法和接口都与之类似。常用的线性代数函数如表3-7所示。\n", - "\n", - "表3-7: 常用的线性代数函数\n", - "\n", - "|函数|功能|\n", - "|:---:|:---:|\n", - "|trace|对角线元素之和(矩阵的迹)|\n", - "|diag|对角线元素|\n", - "|triu/tril|矩阵的上三角/下三角,可指定偏移量|\n", - "|mm/bmm|矩阵乘法,batch的矩阵乘法|\n", - "|addmm/addbmm/addmv/addr/badbmm..|矩阵运算\n", - "|t|转置|\n", - "|dot/cross|内积/外积\n", - "|inverse|求逆矩阵\n", - "|svd|奇异值分解\n", - "\n", - "具体使用说明请参见官方文档[^3],需要注意的是,矩阵的转置会导致存储空间不连续,需调用它的`.contiguous`方法将其转为连续。\n", - "[^3]: http://pytorch.org/docs/torch.html#blas-and-lapack-operations" - ] - }, - { - "cell_type": "code", - "execution_count": 70, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 70, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = a.t()\n", - "b.is_contiguous()" - ] - }, - { - "cell_type": "code", - "execution_count": 71, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 9\n", - " 3 12\n", - " 6 15\n", - "[torch.FloatTensor of size 3x2]" - ] - }, - "execution_count": 71, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.contiguous()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1.2 Tensor和Numpy\n", - "\n", - "Tensor和Numpy数组之间具有很高的相似性,彼此之间的互操作也非常简单高效。需要注意的是,Numpy和Tensor共享内存。由于Numpy历史悠久,支持丰富的操作,所以当遇到Tensor不支持的操作时,可先转成Numpy数组,处理后再转回tensor,其转换开销很小。" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 1., 1.],\n", - " [1., 1., 1.]], dtype=float32)" - ] - }, - "execution_count": 72, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "a = np.ones([2, 3],dtype=np.float32)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 1 1\n", - " 1 1 1\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 73, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = t.from_numpy(a)\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 74, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 1 1\n", - " 1 1 1\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 74, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = t.Tensor(a) # 也可以直接将numpy对象传入Tensor\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 100 1\n", - " 1 1 1\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 75, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[0, 1]=100\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 76, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 1., 100., 1.],\n", - " [ 1., 1., 1.]], dtype=float32)" - ] - }, - "execution_count": 76, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c = b.numpy() # a, b, c三个对象共享内存\n", - "c" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**注意**: 当numpy的数据类型和Tensor的类型不一样的时候,数据会被复制,不会共享内存。" - ] - }, - { - "cell_type": "code", - "execution_count": 77, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 1., 1.],\n", - " [1., 1., 1.]])" - ] - }, - "execution_count": 77, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = np.ones([2, 3])\n", - "a # 注意和上面的a的区别(dtype不是float32)" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 1 1\n", - " 1 1 1\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 78, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = t.Tensor(a) # FloatTensor(double64或者float64)\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 1 1\n", - " 1 1 1\n", - "[torch.DoubleTensor of size 2x3]" - ] - }, - "execution_count": 79, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c = t.from_numpy(a) # 注意c的类型(DoubleTensor)\n", - "c" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 1 1\n", - " 1 1 1\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 80, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[0, 1] = 100\n", - "b # b与a不通向内存,所以即使a改变了,b也不变" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 1 100 1\n", - " 1 1 1\n", - "[torch.DoubleTensor of size 2x3]" - ] - }, - "execution_count": 81, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c # c与a共享内存" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "广播法则(broadcast)是科学运算中经常使用的一个技巧,它在快速执行向量化的同时不会占用额外的内存/显存。\n", - "Numpy的广播法则定义如下:\n", - "\n", - "- 让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分通过在前面加1补齐\n", - "- 两个数组要么在某一个维度的长度一致,要么其中一个为1,否则不能计算 \n", - "- 当输入数组的某个维度的长度为1时,计算时沿此维度复制扩充成一样的形状\n", - "\n", - "PyTorch当前已经支持了自动广播法则,但是笔者还是建议读者通过以下两个函数的组合手动实现广播法则,这样更直观,更不易出错:\n", - "\n", - "- `unsqueeze`或者`view`:为数据某一维的形状补1,实现法则1\n", - "- `expand`或者`expand_as`,重复数组,实现法则3;该操作不会复制数组,所以不会占用额外的空间。\n", - "\n", - "注意,repeat实现与expand相类似的功能,但是repeat会把相同数据复制多份,因此会占用额外的空间。" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "a = t.ones(3, 2)\n", - "b = t.zeros(2, 3,1)" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "(0 ,.,.) = \n", - " 1 1\n", - " 1 1\n", - " 1 1\n", - "\n", - "(1 ,.,.) = \n", - " 1 1\n", - " 1 1\n", - " 1 1\n", - "[torch.FloatTensor of size 2x3x2]" - ] - }, - "execution_count": 83, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 自动广播法则\n", - "# 第一步:a是2维,b是3维,所以先在较小的a前面补1 ,\n", - "# 即:a.unsqueeze(0),a的形状变成(1,3,2),b的形状是(2,3,1),\n", - "# 第二步: a和b在第一维和第三维形状不一样,其中一个为1 ,\n", - "# 可以利用广播法则扩展,两个形状都变成了(2,3,2)\n", - "a+b" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "(0 ,.,.) = \n", - " 1 1\n", - " 1 1\n", - " 1 1\n", - "\n", - "(1 ,.,.) = \n", - " 1 1\n", - " 1 1\n", - " 1 1\n", - "[torch.FloatTensor of size 2x3x2]" - ] - }, - "execution_count": 84, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 手动广播法则\n", - "# 或者 a.view(1,3,2).expand(2,3,2)+b.expand(2,3,2)\n", - "a.unsqueeze(0).expand(2, 3, 2) + b.expand(2,3,2)" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "metadata": {}, - "outputs": [], - "source": [ - "# expand不会占用额外空间,只会在需要的时候才扩充,可极大节省内存\n", - "e = a.unsqueeze(0).expand(10000000000000, 3,2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1.3 内部结构\n", - "\n", - "tensor的数据结构如图3-1所示。tensor分为头信息区(Tensor)和存储区(Storage),信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组。由于数据动辄成千上万,因此信息区元素占用内存较少,主要内存占用则取决于tensor中元素的数目,也即存储区的大小。\n", - "\n", - "一般来说一个tensor有着与之相对应的storage, storage是在data之上封装的接口,便于使用,而不同tensor的头信息一般不同,但却可能使用相同的数据。下面看两个例子。\n", - "\n", - "![图3-1: Tensor的数据结构](imgs/tensor_data_structure.svg)" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " 0.0\n", - " 1.0\n", - " 2.0\n", - " 3.0\n", - " 4.0\n", - " 5.0\n", - "[torch.FloatStorage of size 6]" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.arange(0, 6)\n", - "a.storage()" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " 0.0\n", - " 1.0\n", - " 2.0\n", - " 3.0\n", - " 4.0\n", - " 5.0\n", - "[torch.FloatStorage of size 6]" - ] - }, - "execution_count": 87, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = a.view(2, 3)\n", - "b.storage()" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 一个对象的id值可以看作它在内存中的地址\n", - "# storage的内存地址一样,即是同一个storage\n", - "id(b.storage()) == id(a.storage())" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0 100 2\n", - " 3 4 5\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 89, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# a改变,b也随之改变,因为他们共享storage\n", - "a[1] = 100\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " 0.0\n", - " 100.0\n", - " 2.0\n", - " 3.0\n", - " 4.0\n", - " 5.0\n", - "[torch.FloatStorage of size 6]" - ] - }, - "execution_count": 90, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c = a[2:] \n", - "c.storage()" - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(94139619931688, 94139619931680)" - ] - }, - "execution_count": 91, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c.data_ptr(), a.data_ptr() # data_ptr返回tensor首元素的内存地址\n", - "# 可以看出相差8,这是因为2*4=8--相差两个元素,每个元素占4个字节(float)" - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0\n", - " 100\n", - "-100\n", - " 3\n", - " 4\n", - " 5\n", - "[torch.FloatTensor of size 6]" - ] - }, - "execution_count": 92, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "c[0] = -100 # c[0]的内存地址对应a[2]的内存地址\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 93, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 6666 100 -100\n", - " 3 4 5\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 93, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d = t.Tensor(c.storage())\n", - "d[0] = 6666\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 94, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 94, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 下面4个tensor共享storage\n", - "id(a.storage()) == id(b.storage()) == id(c.storage()) == id(d.storage())" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(0, 2, 0)" - ] - }, - "execution_count": 95, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.storage_offset(), c.storage_offset(), d.storage_offset()" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 96, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e = b[::2, ::2] # 隔2行/列取一个元素\n", - "id(e.storage()) == id(a.storage())" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "((3, 1), (6, 2))" - ] - }, - "execution_count": 97, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b.stride(), e.stride()" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 98, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e.is_contiguous()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可见绝大多数操作并不修改tensor的数据,而只是修改了tensor的头信息。这种做法更节省内存,同时提升了处理速度。在使用中需要注意。\n", - "此外有些操作会导致tensor不连续,这时需调用`tensor.contiguous`方法将它们变成连续的数据,该方法会使数据复制一份,不再与原来的数据共享storage。\n", - "另外读者可以思考一下,之前说过的高级索引一般不共享stroage,而普通索引共享storage,这是为什么?(提示:普通索引可以通过只修改tensor的offset,stride和size,而不修改storage来实现)。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1.4 其它有关Tensor的话题\n", - "这部分的内容不好专门划分一小节,但是笔者认为仍值得读者注意,故而将其放在这一小节。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 持久化\n", - "Tensor的保存和加载十分的简单,使用t.save和t.load即可完成相应的功能。在save/load时可指定使用的`pickle`模块,在load时还可将GPU tensor映射到CPU或其它GPU上。" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "if t.cuda.is_available():\n", - " a = a.cuda(1) # 把a转为GPU1上的tensor,\n", - " t.save(a,'a.pth')\n", - "\n", - " # 加载为b, 存储于GPU1上(因为保存时tensor就在GPU1上)\n", - " b = t.load('a.pth')\n", - " # 加载为c, 存储于CPU\n", - " c = t.load('a.pth', map_location=lambda storage, loc: storage)\n", - " # 加载为d, 存储于GPU0上\n", - " d = t.load('a.pth', map_location={'cuda:1':'cuda:0'})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 向量化" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "向量化计算是一种特殊的并行计算方式,相对于一般程序在同一时间只执行一个操作的方式,它可在同一时间执行多个操作,通常是对不同的数据执行同样的一个或一批指令,或者说把指令应用于一个数组/向量上。向量化可极大提高科学运算的效率,Python本身是一门高级语言,使用很方便,但这也意味着很多操作很低效,尤其是`for`循环。在科学计算程序中应当极力避免使用Python原生的`for循环`。" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "metadata": {}, - "outputs": [], - "source": [ - "def for_loop_add(x, y):\n", - " result = []\n", - " for i,j in zip(x, y):\n", - " result.append(i + j)\n", - " return t.Tensor(result)" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "222 µs ± 81.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "The slowest run took 11.03 times longer than the fastest. This could mean that an intermediate result is being cached.\n", - "5.58 µs ± 7.27 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "x = t.zeros(100)\n", - "y = t.ones(100)\n", - "%timeit -n 10 for_loop_add(x, y)\n", - "%timeit -n 10 x + y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可见二者有超过40倍的速度差距,因此在实际使用中应尽量调用内建函数(buildin-function),这些函数底层由C/C++实现,能通过执行底层优化实现高效计算。因此在平时写代码时,就应养成向量化的思维习惯。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "此外还有以下几点需要注意:\n", - "- 大多数`t.function`都有一个参数`out`,这时候产生的结果将保存在out指定tensor之中。\n", - "- `t.set_num_threads`可以设置PyTorch进行CPU多线程并行计算时候所占用的线程数,这个可以用来限制PyTorch所占用的CPU数目。\n", - "- `t.set_printoptions`可以用来设置打印tensor时的数值精度和格式。\n", - "下面举例说明。" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16777216.0 16777216.0\n" - ] - }, - { - "data": { - "text/plain": [ - "(199999, 199998)" - ] - }, - "execution_count": 102, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.arange(0, 20000000)\n", - "print(a[-1], a[-2]) # 32bit的IntTensor精度有限导致溢出\n", - "b = t.LongTensor()\n", - "t.arange(0, 200000, out=b) # 64bit的LongTensor不会溢出\n", - "b[-1],b[-2]" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "-0.6379 0.5422 0.0413\n", - " 0.4575 0.8977 2.3465\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 103, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.randn(2,3)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "-0.6378980875 0.5421655774 0.0412697867\n", - "0.4574612975 0.8976946473 2.3464736938\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 104, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.set_printoptions(precision=10)\n", - "a" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.1.5 小试牛刀:线性回归" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "线性回归是机器学习入门知识,应用十分广泛。线性回归利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的,其表达形式为$y = wx+b+e$,$e$为误差服从均值为0的正态分布。首先让我们来确认线性回归的损失函数:\n", - "$$\n", - "loss = \\sum_i^N \\frac 1 2 ({y_i-(wx_i+b)})^2\n", - "$$\n", - "然后利用随机梯度下降法更新参数$\\textbf{w}$和$\\textbf{b}$来最小化损失函数,最终学得$\\textbf{w}$和$\\textbf{b}$的数值。" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "metadata": {}, - "outputs": [], - "source": [ - "import torch as t\n", - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", - "from IPython import display" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "metadata": {}, - "outputs": [], - "source": [ - "# 设置随机数种子,保证在不同电脑上运行时下面的输出一致\n", - "t.manual_seed(1000) \n", - "\n", - "def get_fake_data(batch_size=8):\n", - " ''' 产生随机数据:y=x*2+3,加上了一些噪声'''\n", - " x = t.rand(batch_size, 1) * 20\n", - " y = x * 2 + (1 + t.randn(batch_size, 1))*3\n", - " return x, y" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 107, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD11JREFUeJzt3V+MXGd9xvHvU8eU5U+1gWxQvEAN\nKHKpSLHpKkobKaJA64AQMVFRSVtktbShEqhQkEVML4CLKkHmj6peRAokTS5oVArGQS3FWCFtWqmk\n3eAQO3XdFMqfrN14KSzQsqKO+fVix2Bv1t6Z9c7OzLvfj7SamXfP6DxaK0/mvOedc1JVSJJG308N\nOoAkaXVY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGXLSWO7vkkktq8+bNa7lL\nSRp5Dz744LeqamK57da00Ddv3sz09PRa7lKSRl6Sr3eznVMuktQIC12SGmGhS1Ijli30JE9N8s9J\nvpzkkSTv74y/IMkDSR5N8pdJntL/uJKkc+nmE/oPgVdU1UuBrcC1Sa4CPgB8pKouB74DvLl/MSVJ\ny1l2lUst3AHjfzovN3Z+CngF8Jud8buA9wG3rn5ESRpN+w7OsGf/UY7NzbNpfIxd27ewY9tk3/bX\n1Rx6kg1JHgJOAAeArwBzVfVEZ5PHgP6llKQRs+/gDLv3HmJmbp4CZubm2b33EPsOzvRtn10VelWd\nqqqtwHOBK4EXL7XZUu9NcmOS6STTs7OzK08qSSNkz/6jzJ88ddbY/MlT7Nl/tG/77GmVS1XNAX8H\nXAWMJzk9ZfNc4Ng53nNbVU1V1dTExLJfdJKkJhybm+9pfDV0s8plIsl45/kY8CrgCHAf8OudzXYC\n9/QrpCSNmk3jYz2Nr4ZuPqFfBtyX5GHgX4ADVfXXwLuBdyb5D+DZwO19SylJI2bX9i2Mbdxw1tjY\nxg3s2r6lb/vsZpXLw8C2Jca/ysJ8uiRpkdOrWdZylcuaXpxLktaTHdsm+1rgi/nVf0lqhIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtd\nkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWp\nERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqRHLFnqS5yW5L8mRJI8keXtn/H1J\nZpI81Pl5Tf/jSpLO5aIutnkCeFdVfSnJM4EHkxzo/O4jVfXB/sWTJHVr2UKvquPA8c7z7yc5Akz2\nO5gkqTc9zaEn2QxsAx7oDL0tycNJ7khy8SpnkyT1oOtCT/IM4FPAO6rqe8CtwIuArSx8gv/QOd53\nY5LpJNOzs7OrEFmStJSuCj3JRhbK/ONVtRegqh6vqlNV9SPgo8CVS723qm6rqqmqmpqYmFit3JKk\nRbpZ5RLgduBIVX34jPHLztjs9cDh1Y8nSepWN6tcrgbeBBxK8lBn7D3ADUm2AgV8DXhLXxJKkrrS\nzSqXfwSyxK8+u/pxJEkr5TdFJakRFrokNcJCl6RGdHNSVGrSvoMz7Nl/lGNz82waH2PX9i3s2OaX\noDW6LHStS/sOzrB77yHmT54CYGZunt17DwFY6hpZTrloXdqz/+iPy/y0+ZOn2LP/6IASSRfOQte6\ndGxuvqdxaRRY6FqXNo2P9TQujQILXevSru1bGNu44ayxsY0b2LV9y4ASSRfOk6Jal06f+HSVi1pi\noWvd2rFt0gJXU5xykaRGWOiS1AgLXZIaYaFLUiMsdElqhKtcJKlHw3phNwtdknowzBd2c8pFknow\nzBd2s9AlqQfDfGE3C12SejDMF3az0CWpB8N8YTdPikpSD4b5wm4WuiT1aFgv7OaUiyQ1wkKXpEZY\n6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIasWyhJ3lekvuSHEny\nSJK3d8afleRAkkc7jxf3P64k6Vy6+YT+BPCuqnoxcBXw1iQ/D9wE3FtVlwP3dl5rBO07OMPVt3yB\nF9z0N1x9yxfYd3Bm0JEkrcCyhV5Vx6vqS53n3weOAJPAdcBdnc3uAnb0K6T65/QNb2fm5il+csNb\nS10aPT3NoSfZDGwDHgCeU1XHYaH0gUtXO5z6b5hveCupN10XepJnAJ8C3lFV3+vhfTcmmU4yPTs7\nu5KM6qNhvuGtpN50VehJNrJQ5h+vqr2d4ceTXNb5/WXAiaXeW1W3VdVUVU1NTEysRmatomG+4a2k\n3nSzyiXA7cCRqvrwGb/6DLCz83wncM/qx1O/DfMNbyX1ppt7il4NvAk4lOShzth7gFuATyR5M/AN\n4A39iah+GuYb3krqTapqzXY2NTVV09PTa7Y/SWpBkgeramq57fymqCQ1wkKXpEZY6JLUCAtdkhph\noUtSI7pZtqhVsu/gjMsDJfWNhb5GTl8E6/R1U05fBAuw1CWtCgt9jZzvIlgW+uB41KSWWOhrxItg\nDR+PmtQaT4quES+CNXy8dLBaY6GvES+CNXw8alJrLPQ1smPbJDdffwWT42MEmBwf4+brr/DQfoA8\nalJrnENfQzu2TVrgQ2TX9i1nzaGDR00abRa61i0vHazWWOha1zxqUkucQ5ekRljoktQIC12SGmGh\nS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrok\nNWIkbnCx7+CMd5WRpGUMfaHvOzhz1n0fZ+bm2b33EIClLklnGPoplz37j551E1+A+ZOn2LP/6IAS\nSdJwGvpCPzY339O4JK1XQ1/om8bHehqXpPVq2UJPckeSE0kOnzH2viQzSR7q/LymXwF3bd/C2MYN\nZ42NbdzAru1b+rVLSRpJ3XxCvxO4donxj1TV1s7PZ1c31k/s2DbJzddfweT4GAEmx8e4+forPCEq\nSYssu8qlqu5Psrn/Uc5tx7ZJC1ySlnEhc+hvS/JwZ0rm4lVLJElakZUW+q3Ai4CtwHHgQ+faMMmN\nSaaTTM/Ozq5wd5Kk5ayo0Kvq8ao6VVU/Aj4KXHmebW+rqqmqmpqYmFhpTknSMlZU6EkuO+Pl64HD\n59pWkrQ2lj0pmuRu4OXAJUkeA94LvDzJVqCArwFv6WNGSVIXulnlcsMSw7f3IYsk6QIM/TdFJUnd\nsdAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgL\nXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAl\nqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGLFvoSe5IciLJ\n4TPGnpXkQJJHO48X9zemJGk53XxCvxO4dtHYTcC9VXU5cG/ntSRpgJYt9Kq6H/j2ouHrgLs6z+8C\ndqxyLklSj1Y6h/6cqjoO0Hm8dPUiSZJWou8nRZPcmGQ6yfTs7Gy/dydJ69ZKC/3xJJcBdB5PnGvD\nqrqtqqaqampiYmKFu5MkLWelhf4ZYGfn+U7gntWJI0laqW6WLd4N/BOwJcljSd4M3AL8apJHgV/t\nvJYkDdBFy21QVTec41evXOUskqQL4DdFJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUu\nSY2w0CWpERa6JDXCQpekRix7LZdRs+/gDHv2H+XY3DybxsfYtX0LO7ZNDjqWJPVdU4W+7+AMu/ce\nYv7kKQBm5ubZvfcQgKUuqXlNTbns2X/0x2V+2vzJU+zZf3RAiSRp7TRV6Mfm5nsal6SWNFXom8bH\nehqXpJY0Vei7tm9hbOOGs8bGNm5g1/YtA0okSWunqZOip098uspF0nrUVKHDQqlb4JLWo6amXCRp\nPbPQJakRFrokNcJCl6RGWOiS1IhU1drtLJkFvr7MZpcA31qDOBfCjKtnFHKacXWMQkYYzpw/W1UT\ny220poXejSTTVTU16BznY8bVMwo5zbg6RiEjjE7OpTjlIkmNsNAlqRHDWOi3DTpAF8y4ekYhpxlX\nxyhkhNHJ+SRDN4cuSVqZYfyELklagaEq9CRfS3IoyUNJpgedZylJxpN8Msm/JTmS5JcGnelMSbZ0\n/n6nf76X5B2DzrVYkj9K8kiSw0nuTvLUQWdaLMnbO/keGaa/YZI7kpxIcviMsWclOZDk0c7jxUOY\n8Q2dv+WPkgx8Fck5Mu7p/Lf9cJJPJxkfZMZeDVWhd/xKVW0d4mVDfwp8rqp+DngpcGTAec5SVUc7\nf7+twC8CPwA+PeBYZ0kyCfwhMFVVLwE2AG8cbKqzJXkJ8PvAlSz8O782yeWDTfVjdwLXLhq7Cbi3\nqi4H7u28HqQ7eXLGw8D1wP1rnmZpd/LkjAeAl1TVLwD/Duxe61AXYhgLfWgl+RngGuB2gKr6v6qa\nG2yq83ol8JWqWu7LXINwETCW5CLgacCxAedZ7MXAF6vqB1X1BPD3wOsHnAmAqrof+Pai4euAuzrP\n7wJ2rGmoRZbKWFVHqmpobvB7joyf7/x7A3wReO6aB7sAw1boBXw+yYNJbhx0mCW8EJgF/jzJwSQf\nS/L0QYc6jzcCdw86xGJVNQN8EPgGcBz4blV9frCpnuQwcE2SZyd5GvAa4HkDznQ+z6mq4wCdx0sH\nnKcFvwv87aBD9GLYCv3qqnoZ8GrgrUmuGXSgRS4CXgbcWlXbgP9l8Ie2S0ryFOB1wF8NOstinfnd\n64AXAJuApyf57cGmOltVHQE+wMIh+OeALwNPnPdNakaSP2bh3/vjg87Si6Eq9Ko61nk8wcK875WD\nTfQkjwGPVdUDndefZKHgh9GrgS9V1eODDrKEVwH/WVWzVXUS2Av88oAzPUlV3V5VL6uqa1g4NH90\n0JnO4/EklwF0Hk8MOM/ISrITeC3wWzVi67qHptCTPD3JM08/B36NhcPeoVFV/wV8M8npu06/EvjX\nAUY6nxsYwumWjm8AVyV5WpKw8HccqpPLAEku7Tw+n4WTecP69wT4DLCz83wncM8As4ysJNcC7wZe\nV1U/GHSeXg3NF4uSvJCfrMa4CPiLqvqTAUZaUpKtwMeApwBfBX6nqr4z2FRn68z5fhN4YVV9d9B5\nlpLk/cBvsHBYexD4var64WBTnS3JPwDPBk4C76yqewccCYAkdwMvZ+GqgI8D7wX2AZ8Ans/C/zDf\nUFWLT5wOOuO3gT8DJoA54KGq2j5kGXcDPw38d2ezL1bVHwwk4AoMTaFLki7M0Ey5SJIujIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1Ij/h/CJYJPfXoR0gAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# 来看看产生的x-y分布\n", - "x, y = get_fake_data()\n", - "plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4VGXax/HvDQQINVSBQKQaQEDA\nCCgWVBTsiNvsroX1XX13fXcXKWJZK4p1d11d7O66lpUAFhSxgV1BMKGFjhBCJ9RA2vP+MRM3hEky\nyfSZ3+e6uDJz5kzm9nhyz5nnPPM75pxDRETiX51IFyAiIuGhhi8ikiDU8EVEEoQavohIglDDFxFJ\nEGr4IiIJQg1fRCRBqOGLiCQINXwRkQRRL5wv1rp1a9e5c+dwvqSISMjkHygiN7+A0gqJBY3q1+Xo\nVo2pV8cOW3fznoMUlZSSVLcO7Zo1JKVRkl+vs2DBgu3OuTaB1hvWht+5c2fmz58fzpcUEQmZoZM/\npji/4IjlqSnJfDH+jJ/uz1iYy4TMbFoXlfy0LCmpLpNG92XUgNRqX8fM1gejXr+HdMysrpktNLN3\nvPe7mNk3ZrbSzF43s/rBKEhEJFbk+mj2AJsqLJ8yO4eCcs0eoKCohCmzc0JWmy81GcP/PbCs3P0H\ngceccz2AXcB1wSxMRCRabd17kJte+b7SxzukJB92v+IbQHXLQ8Wvhm9mHYHzgGe99w04A3jTu8pL\nwKhQFCgiEi2cc7zx3QaGPzKXOcu2cF7f9jSsd3gbTU6qy9gR6Yctq/gGUN3yUPH3CP9x4Fag1Hu/\nFZDvnCv23t8IVD8QJSISo9Zt389lz3zDrdOy6NW+Ge///hSevHwgky/pR2pKMoZn7P4BH+PyY0ek\nk5xU97Blvt4YQq3ak7Zmdj6w1Tm3wMyGlS32sarPYH0zGwOMAUhLS6tlmSIikVFUUsqzn63l8Q9X\nUL9eHR4Y3ZdfZnSijncGzqgBqdWeeC17fMrsHDblF9AhJZmxI9L9OmEbTP7M0hkKXGhm5wINgWZ4\njvhTzKye9yi/I7DJ15Odc1OBqQAZGRm62oqIxIzsjbsZNy2LpXl7OKdPO/584bG0bdbwsHVmLMz1\nq5H788YQatU2fOfcBGACgPcI/0/OucvN7D/Az4DXgKuBmSGsU0QkbA4UFvPYnBU89/laWjdpwNNX\nHM/IPu2OWK9sumXZDJzc/AImZGYDRLy5+xLIPPxxwGtmdi+wEHguOCWJiETOZyu3MXF6Nht2FnDp\noDTGn9OT5sm+vyBV1XTLmG/4zrlPgU+9t9cAg4JfkohI+O3aX8g97y4l8/tcurZuzOtjhjC4a6sq\nnxMt0y39FdZv2oqIRBvnHG/9sIm7317K7oIi/veM7tx0encaVphV40uHlGSfX74K93RLf6nhi0jC\nys0vYNL0bD7J2cZxnVJ45ZK+9GzXzO/njx2RftgYPkRmuqW/1PBFJOGUlDr++dU6HvJGG9xxfm+u\nPqkzdev4mnFeuWiZbukvNXwRSSg5m/cybloWizbkMyy9DfeO6kPHFo1q/fuiYbqlv9TwRSQk/J2f\nHi4Hi0r4+yereGruapo2TOKJX/XnwuM64EmKSQxq+CISdNE2P/3btTsZn5nFmm37GT0glUnn96Zl\n48QL+FXDF5Ggi5b56XsOFvHge8t55Zsf6dgimZevHcSpxwR8HZGYpYYvIkEXDfPTP1iymdtnLmbb\n3kNcf3IX/nD2MTSqn9gtL7H/60UkJCI5P33rnoPc9fYSZmVvpme7pky9MoPjOqWE/HVjgRq+iARd\nJOanO+d4/bsN3DdrGYeKSxk7Ip0xp3YlqW5NrvMU39TwRSTowj0/fe32/UzIzOLrNTsZ3KUlD4zu\nS9c2TULyWrFMDV9EQiIc89OLSkp55rM1PP7hShrUq8Pk0X35RbmsejmcGr6IxKSsjfmMm5bNsrw9\nnNu3HXddcGRWvRxODV9EYsqBwmIe/WAFz3+xljZNG/CPK49nxLFHZtXLkdTwRSRmzFvhyarfuKuA\nywenMe6cnjRr6DurXo6khi8iEVGT6IWd+wu5952lZC7MpWubxrzxmxMZ1KVlmCuOfWr4IhJ2/kYv\nlGXV//ntpewpKOJ3Z3Tnt35m1cuR1PBFJOz8iV7YuOsAk2Ys5tOcbfTvlMLkGmbVy5Gqbfhm1hCY\nBzTwrv+mc+5OM3sROA3Y7V31GufcolAVKiLxo6rohZJSx8tfrWOKN6v+zgt6c9WJNc+qlyP5c4R/\nCDjDObfPzJKAz83sPe9jY51zb4auPBGJR5VFL7Rp2oDRT33JD0HKqpfDVfudY+exz3s3yfvPhbQq\nEYlrY0ekk1xhHL5eHWP7vkNs2HmAJ37VnxeuOUHNPsj8Cpkws7pmtgjYCsxxzn3jfeg+M8sys8fM\nrEElzx1jZvPNbP62bduCVLaIxLJRA1J5YHRfUr1havXqGMWljlEDUvnwD6dxUf/UhLowSbj41fCd\ncyXOuf5AR2CQmfUBJgA9gROAlsC4Sp471TmX4ZzLaNMmcXOoReRwZ/Rqy2npnp7QrnlDXr52EI/+\non9CXpgkXGo0S8c5l29mnwIjnXMPexcfMrMXgD8FuzgRiU+zl2zmDm9W/Q2ndOH/zlJWfTj4M0un\nDVDkbfbJwHDgQTNr75zLM8/nrlHA4hDXKiIxbsueg9w5cwnvL9lMr/bNeOaqDPp1VFZ9uPjzltoe\neMnM6uIZAnrDOfeOmX3sfTMwYBFwYwjrFJEYVlrqeH3+Bu6ftYzC4lLGjezJ9ad0UVZ9mFXb8J1z\nWcAAH8vPCElFIhJX1mzbx4TMbL5Zu5MhXVvywOh+dGndONJlJSQNmolISBSVlDJ13hqe+GglDevV\n4cFLPFn1mn0TOWr4IhJ0P2zIZ9y0LJZv3st5fdtz54W9adtUWfWRpoYvIkFzoLCYRz5YwQtfrKVt\n04ZMvfJ4zlZWfdRQwxeRoJi7Yhu3ebPqrxiSxq0jlVUfbdTwRSQg5bPqu7VpzH9uPJETOiurPhqp\n4YtIrTjnmLloE3e/s5S9B4v43Zk9uOn0bjSoF9qs+ppcOEUOp4YvIjW2YecBbpuxmHkrtjEgLYXJ\no/uR3q5pyF/X3wuniG9q+CLit5JSx4tfruPh2TnUMfjzhcdyxZCjj8iqD9VRuD8XTpHKqeGLiF+W\n5e1h/LQsfti4m9PT23Dvxf9NuywvlEfhVV04Raqnhi8iVTpYVMLfPl7F03NX0zw5ib9cOoAL+rWv\n9AtUoTwKr+zCKR18vPHIkdTwRaRSX6/ZwcTMbNZs388lAzsy6bxetKgmvrg2R+H+DgGNHZF+2KcH\ngOSkuowdke7nf1FiU8MXkSPsLihi8nvLefXbH+nUMpl/XjeIU3r4dz2Lmh6F12QIqOy+ZunUjhq+\niBzm/cWerPrt+w4x5tSu3DK8R42y6mt6FF7TIaBRA1LV4GtJDV9EgMOz6nu3b8ZzV59A347Na/x7\nanoUrhOx4aOGL5LgSksdr323gQfe82TV3zoynRtO6RpQVn1NjsJ1IjZ81PBFEthqb1b9t2t3cmLX\nVjwwui+dw5xVrxOx4aOGL5KACotLmTpvNX/5eBUN69XhoUv68fOMjjXKqg/Wl6t0IjZ8/LmmbUNg\nHtDAu/6bzrk7zawL8BrQEvgeuNI5VxjKYkUkcIs25DO+LKu+X3vuvKDmWfXB/nKVTsSGhz+DdIeA\nM5xzxwH9gZFmNgR4EHjMOdcD2AVcF7oyRSRQ+w8Vc/fbS7n471+Qf6CIZ67K4MnLBtbqwiRVzayR\n6OXPNW0dsM97N8n7zwFnAJd5l78E3AU8FfwSRSRQn+Zs5bbpi8nNL+DKIUdz68h0mgaQVa+ZNbHJ\nrzF8M6sLLAC6A08Cq4F851yxd5WNgD6PiUSZHfsOcc87S5mxaBPd2zbhzRtPJCMIWfWaWROb/Jp3\n5Zwrcc71BzoCg4Bevlbz9VwzG2Nm881s/rZt22pfqYj4zTnH9IUbGf7oXN7NzuP3Z/bg3d+dHJRm\nD56ZNclJh+fea2ZN9KvRLB3nXL6ZfQoMAVLMrJ73KL8jsKmS50wFpgJkZGT4fFMQkeApn1U/MC2F\nyZf045ijgptVr5k1scmfWTptgCJvs08GhuM5YfsJ8DM8M3WuBmaGslARqVpJqeOFL9byyAcrqGNw\n90XHcsXgo6lTx/+pljWhmTWxx58j/PbAS95x/DrAG865d8xsKfCamd0LLASeC2GdIlKFpZv2MCHT\nk1V/Zs+23DOqj8bT5Qj+zNLJAgb4WL4Gz3i+iETIwaIS/vLRSqbOW0NKo+qz6iWx6Zu2IjHq6zU7\nmJCZzdrt+/n58R257bxepDSqOqteEpsavkgIheLarp6s+mW8+u0G0lo24l/XDebkHq2DVLHEMzV8\nkRAJxbVd31+cx+0zl7Bj3yF+c2pXbhl+DMn161b/RBHU8EVCJpjXdt28+yB3zFzMB0u3cGyHZrxw\nzQn0Sa15Vr0kNjV8kRAJRvxAaanj1e9+ZPKs5RSWlDLhnJ5cd3IX6gWQVS+JSw1fJEQCjR9YvW0f\nE6Zl8+26nZzUrRX3Xxz+rHqJLzpMEAmR2sYPFBaX8rePV3LO45+Rs2UvD/2sH69cP1jNXgKmI3yR\nEKlN/MDCH3cxflo2OVtqn1UvUhk1fJEQ8jd+YP+hYh7+IIcXv1xHu2YNefaqDIb3PioMFUoiUcMX\nibBPcrYyafpiNu0u4IrBgWfVi1RGDV8kQnxl1R9/dHDii0V8UcMXCTNPVn0u97yzlH2HirlleA/+\nZ1g3GtTTF6gktNTwRcJow84DTJyezWcrtzMwLYUHL+lHjyBn1YtURg1fJAyKS0p58ct1P2XV33PR\nsVwewqx6EV/U8EVCbOmmPYzPzCJLWfUSYWr4IiFSllX/j3lraNEoib9dNoDz+iqrXiJHDV8kBL5a\nvYOJ05VVL9FFDV8kiHYfKOKB95bx2neerPpXrh/M0O7Kqpfo4M9FzDsBLwPtgFJgqnPuCTO7C7gB\n2OZddaJzblaoChWJZs453l+8mTveWsLO/YX85rSu3HKmsuoluvhzhF8M/NE5972ZNQUWmNkc72OP\nOeceDl15ItFPWfUSK/y5iHkekOe9vdfMlgGBXaNNJA6Uljr+/e2PPPieJ6t+/Dk9uV5Z9RLFajSG\nb2adgQHAN8BQ4GYzuwqYj+dTwC4fzxkDjAFIS0sLsFyR6LBq6z4mZGbx3bpdDO3uyao/upXiiyW6\nmXPOvxXNmgBzgfucc5lmdhSwHXDAPUB759y1Vf2OjIwMN3/+/ABLFomcwuJS/jF3NX/9eBXJ9esy\n6bxe/Oz4jppqKSFlZguccxmB/h6/jvDNLAmYBrzinMsEcM5tKff4M8A7gRYjEs2+/3EXE7xZ9ef3\na8+dFxxLm6YNIl2WiN/8maVjwHPAMufco+WWt/eO7wNcDCwOTYkikbXvUDEPz87hpa+UVS+xzZ8j\n/KHAlUC2mS3yLpsIXGpm/fEM6awDfhOSCkUi6JPlW5k0w5NVf9WQo/nTCGXVS+zyZ5bO54CvAUrN\nuZe4tX3fIe5+eylv/aCseokf+qatSDnOOTK/z+Wed5eyX1n1EmfU8EW8lFUv8U4NXxKesuolUajh\nS0Jbsmk346dlk52rrHqJf2r4kpAOFpXw+IcreeYzT1b9Xy8dwPn9lFUv8U0NXxLOl6u3MzEzm3U7\nDiirXhKKGr7EpRkLc5kyO4dN+QV0SElm7Ih0Tk9vy/2zlvH6fGXVS2JSw5e4M2NhLhMysykoKgEg\nN7+AW9/MokFSHQ4UliirXhKWGr7EnSmzc35q9mUKS0pxOGbeNFRZ9ZKwFNwtcWdTfoHP5UUlTs1e\nEpoavsSdyhIsUzXdUhKchnQkbhQWl/L03NXs2F94xGPJSXUZOyI9AlWJRA81fIkLC9bvYkJmFiu2\n7OOC4zowqHMLnp675rBZOqMG6MqcktjU8CWmVcyqf+7qDM7s5cmqv/LEzhGtTSTaqOFLzPp4+RYm\nTV9M3p6DXDnkaG4d2ZMmDbRLi1RGfx0Sc8pn1fdo24Q3bzyJ449uEemyRKKeGr7EDOcc077P5V5v\nVv3/DT+GG4d1VVa9iJ/8uaZtJ+BloB1QCkx1zj1hZi2B14HOeC5x+Avn3K7QlSqJ7Mcdnqz6z1dt\n5/ijWzB5dF9l1YvUkD9H+MXAH51z35tZU2CBmc0BrgE+cs5NNrPxwHhgXOhKlURUXFLK81+s5dE5\nK6hXp46y6kUC4M81bfOAPO/tvWa2DEgFLgKGeVd7CfgUNXwJosW5uxmfmcXi3D0M7+XJqm/fXF+e\nEqmtGo3hm1lnYADwDXCU980A51yembUNenWSkA7Pqq/Pk5cN5Ny+7ZRVLxIgvxu+mTUBpgG3OOf2\n+PvHZ2ZjgDEAaWlptalREsiXq7YzYXo263cc4JcZnZh4bi+aN0qKdFkiccGvhm9mSXia/SvOuUzv\n4i1m1t57dN8e2Orruc65qcBUgIyMDBeEmiUO7T5QxH2zlvLG/I0c3aoR/75+MCcpq14kqPyZpWPA\nc8Ay59yj5R56C7gamOz9OTMkFUpcc87xbnYed721lF0HCrnxtG7cMrwHDZM01VIk2Pw5wh8KXAlk\nm9ki77KJeBr9G2Z2HfAj8PPQlCjxKm93AbfPWMyHy7bSJ7UZL/76BMUXi4SQP7N0PgcqG7A/M7jl\nSCIoLXW88s16Hnw/h+LSUiae25Nrh3ahXl2ldYuEkr5pK2G1cstexmdms2D9Lk7u3pr7L+5LWqtG\nkS5LJCGo4UtYHCou4alPV/P3T1bTqEFdHvn5cYwemKqpliJhpIYvIbdg/S7GT8ti5dZ9XHhcB+64\noDetm/i+KpWIhI4avoTMvkPFTHl/OS9/vZ72zRry/DUZnNHzqEiXJZKw1PAlJD5atoVJMxazec9B\nrj6xM38aka6sepEI01+gBNW2vYf489tLeCcrj2OOasKTl5/EwDRl1YtEAzV8CQrnHG8u2Mi97y6j\noLCEP5x1DDee1o369TTVUiRaqOFLwNbv2M/E6dl8sWoHGUe3YPIlfeneVln1ItFGDV9q7Yis+lF9\nuHxQmrLqRaKUGr7UyuFZ9Udxz6hjlVUvEuXU8KVGCgpLePyjFTz72Vpl1YvEGDV88dsXq7YzITOb\nH3cqq14kFqnhS7XyDxRy37vL+M+CjXRu1Yh/3zCYk7opq14k1qjhS6X+m1W/hF0HivifYd34/ZnK\nqheJVWr44tOm/ALumOnJqu+b2pyXrh3EsR2UVS8Sy9Tw5TClpY5/fbOeB99bTolzTDqvF9ec1FlZ\n9SJxQA1fflI+q/6UHq25b5Sy6kXiiRq+cKi4hL9/spq/f7qKxg3qKateJE75cxHz54Hzga3OuT7e\nZXcBNwDbvKtNdM7NClWREjoL1u9k3LRsVm3dx0X9O3D7+cqqF4lX/hzhvwj8DXi5wvLHnHMPB70i\nCYu9B4uYMjuHf369ng7Nk3nhmhM4vWfbSJclIiHkz0XM55lZ59CXIuHy4VJPVv2WvQe55qTO/Ons\ndBorq14k7gXyV36zmV0FzAf+6Jzb5WslMxsDjAFIS0sL4OUkUNv2HuKut5fwblYe6Uc15e9XDFRW\nvUgCqe1cu6eAbkB/IA94pLIVnXNTnXMZzrmMNm3a1PLlJBDOOd6Yv4Hhj85lzpIt/PGsY3j7f09W\nsxdJMLU6wnfObSm7bWbPAO8ErSIJqvU79jMhM5svV+9gUOeW3D+6L93bNol0WSISAbVq+GbW3jmX\n5717MbA4eCVJMBSXlPLs52t5bM4K6tetw30X9+HSE5RVL5LI/JmW+SowDGhtZhuBO4FhZtYfcMA6\n4DchrFFqaHHubsZNy2LJpj2c1fso7rmoD+2aN4x0WSISYf7M0rnUx+LnQlCLBKigsITHP1zBs5+v\npWXj+jx1+UBG9lFWvYh4aC5enCifVX/poE6MH6msehE5nBp+jMs/UMi97y7jzQUb6dK6Ma/eMIQT\nu7WKdFkiEoXU8GOUc463s/K4+21PVv1vh3Xjd8qqF5EqqOHHoE35Bdw+YzEfLd9Kv47NefnawfTu\n0CzSZYlIlFPDjyElpY5/fb2eh95fTqlDWfUiUiNq+DFixZa9jJ+Wxfc/5nNKj9bcf3FfOrVUVr2I\n+E8NP8odKi7hyU9W89Snq2jSoB6P/fI4RvVXVr2I1JwafhjMWJjLlNk5bMovoENKMmNHpDNqQGq1\nz5u/bifjMz1Z9aO8WfWtlFUvIrWkhh9iMxbmMiEzm4KiEgBy8wuYkJkN4LPpz1iYy4PvLydv90EA\nWjRK4oVfn8Dp6cqqF5HAqOGH2JTZOT81+zIFRSVMmZ1zRMOfsTCXW9/MorCk9L/rFpaw+0BRta9T\n208RIpI4NL0jxDblF/i1fOveg0zIzD6s2QMcLC5lyuycKl+j7FNEbn4Bjv9+ipixMDeg2kUkvqjh\nh1iHlOQqlzvneOO7DQx/ZO4RnwTKVPamUaaqTxEiImXU8ENs7Ih0kn18+3X/oWKembeGy575hlun\nZdGzfTPaNvV9QrayN40y/n6KEJHEpoYfYqMGpPLA6L60qBBkll9QxH2zlrFwwy7uv7gvr90whInn\n9jrizSE5qS5jR6RX+RrVfYoQEQE1/LAYNSCVRvV9nx9PSa7PZYM9FyYpe3NITUnGgNSUZB4Y3bfa\nk6++PkX480YhIolFs3TCJLeS4ZUtew4edn/UgNQaz64pW782s3Q0u0ckcajhh8FnK7dRt45RUuqO\neCxYwy61eaOo6XcERCS2VTukY2bPm9lWM1tcbllLM5tjZiu9P1uEtszYtGt/IX984weufO5bWjaq\nT/0KIWeRHnbR7B6RxOLPGP6LwMgKy8YDHznnegAfee/HlRkLcxk6+WO6jH+XoZM/rtGcduccb/2w\nieGPzmXmolxuOr0bn407nYd+1q/G4/OhpNk9IonFn2vazjOzzhUWX4TnwuYALwGfAuOCWFdEBTLU\nkZtfwKTp2XySs43jOjbnX9cPplf7Zj89N5qGSjqkJPs8t6DZPSLxqbazdI5yzuUBeH/GVdBLbYY6\nSkodL36xlrMencvXa3Zy+/m9yfzt0J+afTTS7B6RxBLyk7ZmNgYYA5CWlhbqlwuKmg515Gzey7hp\nWSzakM+px7ThvlF9YiKrPpDZPSISe2rb8LeYWXvnXJ6ZtQe2Vraic24qMBUgIyPjyGkqUcjfoY5D\nxSU8+fEqnpq7Omaz6qNtmElEQqe2QzpvAVd7b18NzAxOOdHBn6GO79bt5NwnPuMvH6/i/H4d+PAP\np3HxgI4x1exFJLFUe4RvZq/iOUHb2sw2AncCk4E3zOw64Efg56EsMtyqGurYe7CIB99fzr++/pHU\nlGRe/PUJDFNWvYjEAHMufKMsGRkZbv78+WF7vWD7YMlm7pi5hK17D3LNSV3449nH0LiBvrsmIqFl\nZguccxmB/h51Kz9s3XuQu95awqzszfRs15Snrzye/p1SIl2WiEiNqOFXwTnHG/M3cN+7yzhYXMrY\nEemMObUrSXWVOScisUcNvxJrt+9nYmY2X63ZwaAuLXlgdF+6tWkS6bJERGpNDb+CopJSnvlsDU98\nuJL69erwwOi+/DKjE3XqaPaNiMQ2NfxysjbmM25aNsvy9jDy2Hb8+aJjOapZw0iXJSISFGr4wIHC\nYh6bs4LnPl9L6yYNePqKgYzs0z7SZYmIBFXCN/zPVm5j4vRsNuws4NJBaYw/pyfNk5Oqf6KISIxJ\n2Ia/a38h97y7lMzvc+naujGvjxnC4K6tIl2WiEjIJFzDL8uqv/vtpewuKOLm07tz8xndaVghSkFE\nJN4kVMOvKqteRCTeJUTDLyl1/POrdTw0Owfn4Pbze3PNSZ2pW2GqpS7oLSLxLO4bvr9Z9bqgt4jE\nu7ht+OWz6ps2TOLxX/bnov4dKo0vruoqV2r4IhIP4rLhf7duJ+OnZbF6235GD0hl0vm9adm4fpXP\n0QW9RSTexVXD33OwiIfKZdW/dO0gTjumjV/P1QW9RSTexU3DL59Vf93JXfjDWTXLqh87Iv2wMXzQ\nBb1FJL7EfMOvmFX/jyuP57haZNXrgt4iEu9ituGHIqteF/QWkXgWUMM3s3XAXqAEKA7GJbj8UT6r\nfrA3q76rsupFRKoUjCP8051z24Pwe6qlrHoRkdqLmSGd8ln15/Rpx58vPJa2yqoXEfFboA3fAR+Y\nmQP+4ZybWnEFMxsDjAFIS0ur8QscKCzm0Q9W8PwXZVn1xzOyT7sAyxYRSTyBNvyhzrlNZtYWmGNm\ny51z88qv4H0TmAqQkZHhavLL563wZNVv3FXA5YPTGHdOT5o1VFa9iEhtBNTwnXObvD+3mtl0YBAw\nr+pnVe+wrPo2jXnjNycyqEvLQH+tiEhCq3XDN7PGQB3n3F7v7bOBuwMppmJW/f+e0Z2bTldWvYhI\nMARyhH8UMN0bRlYP+Ldz7v3a/rKNuw4wacZiPs3ZRv9OKbxySV96tlNWvYhIsNS64Tvn1gDHBVpA\nSanjpS/X8fAHOQDceUFvrjrxyKx6EREJTESnZS7fvIdx07L5YUM+w9LbcO+oPnRscWRWvYiIBC4i\nDf9gUQlPfrKKpz5dTbPkJJ74VX8uPK7yrHoREQlc2Bv+t2t3Mj4zizXb9jN6YCqTzqs+q15ERAIX\n1oafm1/AL/7xFR1bJPPytYM41c+sehERCVxYG/7O/YXccUoX/u+sY2hUP2ZSHURE4kJYu273Nk24\n7bze4XxJERHxqn14fC0k19cXqEREIiWsDV9ERCJHDV9EJEGo4YuIJIioniozY2GuLiouIhIkUdvw\nZyzMZUJmNgVFJYBnDv+EzGwANX0RkVqI2iGdKbNzfmr2ZQqKSpgyOydCFYmIxLaobfib8gtqtFxE\nRKoWtQ2/Q0pyjZaLiEjVorbhjx2RTnKFK10lJ9Vl7Ij0CFUkIhLbovakbdmJWc3SEREJjoAavpmN\nBJ4A6gLPOucmB6Uqr1EDUtXgRUSCpNZDOmZWF3gSOAfoDVxqZkpGExGJUoGM4Q8CVjnn1jjnCoHX\ngIuCU5aIiARbIA0/FdhQ7v5G7zIREYlCgYzh+7oArTtiJbMxwBjv3UNmtjiA1wyX1sD2SBfhB9UZ\nPLFQI6jOYIuVOoMyPTGQhr/0vWA7AAAFWUlEQVQR6FTufkdgU8WVnHNTgakAZjbfOZcRwGuGheoM\nrlioMxZqBNUZbLFUZzB+TyBDOt8BPcysi5nVB34FvBWMokREJPhqfYTvnCs2s5uB2XimZT7vnFsS\ntMpERCSoApqH75ybBcyqwVOmBvJ6YaQ6gysW6oyFGkF1BltC1WnOHXGeVURE4lDUZumIiEhwhaTh\nm9lIM8sxs1VmNt7H4w3M7HXv49+YWedQ1FFNjZ3M7BMzW2ZmS8zs9z7WGWZmu81skfffHeGu01vH\nOjPL9tZwxNl68/iLd3tmmdnAMNeXXm4bLTKzPWZ2S4V1IrItzex5M9tafjqwmbU0szlmttL7s0Ul\nz73au85KM7s6AnVOMbPl3v+n080spZLnVrl/hKHOu8wst9z/23MreW6VfSEMdb5ersZ1ZraokueG\nZXtW1oNCun8654L6D88J3NVAV6A+8APQu8I6vwWe9t7+FfB6sOvwo872wEDv7abACh91DgPeCXdt\nPmpdB7Su4vFzgffwfDdiCPBNBGutC2wGjo6GbQmcCgwEFpdb9hAw3nt7PPCgj+e1BNZ4f7bw3m4R\n5jrPBup5bz/oq05/9o8w1HkX8Cc/9osq+0Ko66zw+CPAHZHcnpX1oFDun6E4wvcncuEi4CXv7TeB\nM83M1xe5QsY5l+ec+957ey+wjNj9pvBFwMvO42sgxczaR6iWM4HVzrn1EXr9wzjn5gE7Kywuv/+9\nBIzy8dQRwBzn3E7n3C5gDjAynHU65z5wzhV7736N57suEVXJ9vRHWKNYqqrT22t+Abwaqtf3RxU9\nKGT7Zygavj+RCz+t492hdwOtQlCLX7xDSgOAb3w8fKKZ/WBm75nZsWEt7L8c8IGZLTDPN5criqaY\ni19R+R9SNGxLgKOcc3ng+aMD2vpYJ5q2KcC1eD7F+VLd/hEON3uHnp6vZAgimrbnKcAW59zKSh4P\n+/as0INCtn+GouH7E7ngVyxDOJhZE2AacItzbk+Fh7/HMzRxHPBXYEa46/Ma6pwbiCeZ9CYzO7XC\n41GxPc3zBbwLgf/4eDhatqW/omKbApjZbUAx8Eolq1S3f4TaU0A3oD+Qh2e4pKKo2Z7ApVR9dB/W\n7VlND6r0aT6WVbs9Q9Hw/Ylc+GkdM6sHNKd2HxMDYmZJeDb0K865zIqPO+f2OOf2eW/PApLMrHWY\ny8Q5t8n7cyswHc/H4/L8irkIg3OA751zWyo+EC3b0mtL2ZCX9+dWH+tExTb1now7H7jceQdvK/Jj\n/wgp59wW51yJc64UeKaS14+W7VkPGA28Xtk64dyelfSgkO2foWj4/kQuvAWUnVX+GfBxZTtzqHjH\n8Z4DljnnHq1knXZl5xbMbBCe7bUjfFWCmTU2s6Zlt/GcyKsYQPcWcJV5DAF2l30kDLNKj5yiYVuW\nU37/uxqY6WOd2cDZZtbCO0RxtndZ2JjnAkPjgAudcwcqWcef/SOkKpwvuriS14+WKJbhwHLn3EZf\nD4Zze1bRg0K3f4bo7PO5eM44rwZu8y67G8+OC9AQz8f+VcC3QNdQng2vpMaT8XwEygIWef+dC9wI\n3Ohd52ZgCZ4ZBV8DJ0Wgzq7e1//BW0vZ9ixfp+G5GM1qIBvIiECdjfA08ObllkV8W+J5A8oDivAc\nFV2H53zRR8BK78+W3nUz8Fy5rey513r30VXAryNQ5yo847Rl+2fZzLYOwKyq9o8w1/lP736XhadZ\nta9Yp/f+EX0hnHV6l79Ytk+WWzci27OKHhSy/VPftBURSRD6pq2ISIJQwxcRSRBq+CIiCUINX0Qk\nQajhi4gkCDV8EZEEoYYvIpIg1PBFRBLE/wOF691fEa+RdwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.9918575286865234 2.954965829849243\n" - ] - } - ], - "source": [ - "# 随机初始化参数\n", - "w = t.rand(1, 1) \n", - "b = t.zeros(1, 1)\n", - "\n", - "lr =0.001 # 学习率\n", - "\n", - "for ii in range(20000):\n", - " x, y = get_fake_data()\n", - " \n", - " # forward:计算loss\n", - " y_pred = x.mm(w) + b.expand_as(y) # x@W等价于x.mm(w);for python3 only\n", - " loss = 0.5 * (y_pred - y) ** 2 # 均方误差\n", - " loss = loss.sum()\n", - " \n", - " # backward:手动计算梯度\n", - " dloss = 1\n", - " dy_pred = dloss * (y_pred - y)\n", - " \n", - " dw = x.t().mm(dy_pred)\n", - " db = dy_pred.sum()\n", - " \n", - " # 更新参数\n", - " w.sub_(lr * dw)\n", - " b.sub_(lr * db)\n", - " \n", - " if ii%1000 ==0:\n", - " \n", - " # 画图\n", - " display.clear_output(wait=True)\n", - " x = t.arange(0, 20).view(-1, 1)\n", - " y = x.mm(w) + b.expand_as(x)\n", - " plt.plot(x.numpy(), y.numpy()) # predicted\n", - " \n", - " x2, y2 = get_fake_data(batch_size=20) \n", - " plt.scatter(x2.numpy(), y2.numpy()) # true data\n", - " \n", - " plt.xlim(0, 20)\n", - " plt.ylim(0, 41)\n", - " plt.show()\n", - " plt.pause(0.5)\n", - " \n", - "print(w.squeeze()[0], b.squeeze()[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可见程序已经基本学出w=2、b=3,并且图中直线和数据已经实现较好的拟合。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "虽然上面提到了许多操作,但是只要掌握了这个例子基本上就可以了,其他的知识,读者日后遇到的时候,可以再看看这部份的内容或者查找对应文档。\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.5.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/0_basic/ref_dynamic-graph.ipynb b/6_pytorch/0_basic/ref_dynamic-graph.ipynb deleted file mode 100644 index a1c35e0..0000000 --- a/6_pytorch/0_basic/ref_dynamic-graph.ipynb +++ /dev/null @@ -1,100 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 动态图和静态图\n", - "目前神经网络框架分为[静态图框架和动态图框架](https://blog.csdn.net/qq_36653505/article/details/87875279),PyTorch 和 TensorFlow、Caffe 等框架最大的区别就是他们拥有不同的计算图表现形式。 TensorFlow 使用静态图,这意味着我们先定义计算图,然后不断使用它,而在 PyTorch 中,每次都会重新构建一个新的计算图。通过这次课程,我们会了解静态图和动态图之间的优缺点。\n", - "\n", - "对于使用者来说,两种形式的计算图有着非常大的区别,同时静态图和动态图都有他们各自的优点,比如动态图比较方便debug,使用者能够用任何他们喜欢的方式进行debug,同时非常直观,而静态图是通过先定义后运行的方式,之后再次运行的时候就不再需要重新构建计算图,所以速度会比动态图更快。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmai482qumg30rs0fmq6e.gif)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## PyTorch" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# pytorch\n", - "import torch\n", - "first_counter = torch.Tensor([0])\n", - "second_counter = torch.Tensor([10])" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "while (first_counter < second_counter):\n", - " first_counter += 2\n", - " second_counter += 1" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([20.])\n", - "tensor([20.])\n" - ] - } - ], - "source": [ - "print(first_counter)\n", - "print(second_counter)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可以看到 PyTorch 的写法跟 Python 的写法是完全一致的,没有任何额外的学习成本\n", - "\n", - "上面的例子展示如何使用静态图和动态图构建 while 循环,看起来动态图的方式更加简单且直观,你觉得呢?" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/1-tensor.ipynb b/6_pytorch/1-tensor.ipynb new file mode 100644 index 0000000..f3dcb0f --- /dev/null +++ b/6_pytorch/1-tensor.ipynb @@ -0,0 +1,726 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tensor and Variable\n", + "\n", + "\n", + "张量(Tensor)是一种专门的数据结构,非常类似于数组和矩阵。在PyTorch中,我们使用张量来编码模型的输入和输出,以及模型的参数。\n", + "\n", + "张量类似于`NumPy`的`ndarray`,不同之处在于张量可以在GPU或其他硬件加速器上运行。事实上,张量和NumPy数组通常可以共享相同的底层内存,从而消除了复制数据的需要(请参阅使用NumPy的桥接)。张量还针对自动微分进行了优化,在Autograd部分中看到更多关于这一点的内介绍。\n", + "\n", + "`variable`是一种可以不断变化的变量,符合反向传播,参数更新的属性。PyTorch的`variable`是一个存放会变化值的内存位置,里面的值会不停变化,像装糖果(糖果就是数据,即tensor)的盒子,糖果的数量不断变化。pytorch都是由tensor计算的,而tensor里面的参数是variable形式。\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Tensor基本用法\n", + "\n", + "PyTorch基础的数据是张量(Tensor),PyTorch 的很多操作好 NumPy 都是类似的,但是因为其能够在 GPU 上运行,所以有着比 NumPy 快很多倍的速度。本节内容主要包括 PyTorch 中的基本元素 Tensor 和 Variable 及其操作方式。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1 Tensor定义与生成" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# 创建一个 numpy ndarray\n", + "numpy_tensor = np.random.randn(10, 20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以使用下面两种方式将numpy的ndarray转换到tensor上" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "pytorch_tensor1 = torch.Tensor(numpy_tensor)\n", + "pytorch_tensor2 = torch.from_numpy(numpy_tensor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "使用以上两种方法进行转换的时候,会直接将 NumPy ndarray 的数据类型转换为对应的 PyTorch Tensor 数据类型" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "同时也可以使用下面的方法将 `PyTorch Tensor` 转换为 `NumPy ndarray`" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# 如果 pytorch tensor 在 cpu 上\n", + "numpy_array = pytorch_tensor1.numpy()\n", + "\n", + "# 如果 pytorch tensor 在 gpu 上\n", + "numpy_array = pytorch_tensor1.cpu().numpy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "需要注意 GPU 上的 Tensor 不能直接转换为 NumPy ndarray,需要使用`.cpu()`先将 GPU 上的 Tensor 转到 CPU 上" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2 PyTorch Tensor 使用 GPU 加速\n", + "\n", + "我们可以使用以下两种方式将 Tensor 放到 GPU 上" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# 第一种方式是定义 cuda 数据类型\n", + "dtype = torch.cuda.FloatTensor # 定义默认 GPU 的 数据类型\n", + "gpu_tensor = torch.randn(10, 20).type(dtype)\n", + "\n", + "# 第二种方式更简单,推荐使用\n", + "gpu_tensor = torch.randn(10, 20).cuda(0) # 将 tensor 放到第一个 GPU 上\n", + "gpu_tensor = torch.randn(10, 20).cuda(1) # 将 tensor 放到第二个 GPU 上" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "使用第一种方式将 tensor 放到 GPU 上的时候会将数据类型转换成定义的类型,而是用第二种方式能够直接将 tensor 放到 GPU 上,类型跟之前保持一致\n", + "\n", + "推荐在定义 tensor 的时候就明确数据类型,然后直接使用第二种方法将 tensor 放到 GPU 上" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "而将 tensor 放回 CPU 的操作如下" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "cpu_tensor = gpu_tensor.cpu()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tensor 属性的访问方式" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([10, 20])\n", + "torch.Size([10, 20])\n" + ] + } + ], + "source": [ + "# 可以通过下面两种方式得到 tensor 的大小\n", + "print(pytorch_tensor1.shape)\n", + "print(pytorch_tensor1.size())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.FloatTensor\n", + "torch.cuda.FloatTensor\n" + ] + } + ], + "source": [ + "# 得到 tensor 的数据类型\n", + "print(pytorch_tensor1.type())\n", + "print(gpu_tensor.type())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + } + ], + "source": [ + "# 得到 tensor 的维度\n", + "print(pytorch_tensor1.dim())" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200\n" + ] + } + ], + "source": [ + "# 得到 tensor 的所有元素个数\n", + "print(pytorch_tensor1.numel())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Tensor的操作\n", + "Tensor 操作中的 API 和 NumPy 非常相似,如果熟悉 NumPy 中的操作,那么 tensor 基本操作是一致的,下面列举其中的一些操作" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 基本操作" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1., 1.],\n", + " [1., 1.],\n", + " [1., 1.]])\n" + ] + } + ], + "source": [ + "x = torch.ones(3, 2)\n", + "print(x) # 这是一个float tensor" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.FloatTensor\n" + ] + } + ], + "source": [ + "print(x.type())" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1, 1],\n", + " [1, 1],\n", + " [1, 1]])\n" + ] + } + ], + "source": [ + "# 将其转化为整形\n", + "x = x.long()\n", + "# x = x.type(torch.LongTensor)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1., 1.],\n", + " [1., 1.],\n", + " [1., 1.]])\n" + ] + } + ], + "source": [ + "# 再将其转回 float\n", + "x = x.float()\n", + "# x = x.type(torch.FloatTensor)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-1.2200, 0.9769, -2.3477],\n", + " [ 1.0125, -1.3236, -0.2626],\n", + " [-0.3501, 0.5753, 1.5657],\n", + " [ 0.4823, -0.4008, -1.3442]])\n" + ] + } + ], + "source": [ + "x = torch.randn(4, 3)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# 沿着行取最大值\n", + "max_value, max_idx = torch.max(x, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.9769, 1.0125, 1.5657, 0.4823])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# 每一行的最大值\n", + "max_value" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 0, 2, 0])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# 每一行最大值的下标\n", + "max_idx" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-2.5908, -0.5736, 1.7909, -1.2627])\n" + ] + } + ], + "source": [ + "# 沿着行对 x 求和\n", + "sum_x = torch.sum(x, dim=1)\n", + "print(sum_x)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([4, 3])\n", + "torch.Size([1, 4, 3])\n", + "tensor([[[-1.2200, 0.9769, -2.3477],\n", + " [ 1.0125, -1.3236, -0.2626],\n", + " [-0.3501, 0.5753, 1.5657],\n", + " [ 0.4823, -0.4008, -1.3442]]])\n" + ] + } + ], + "source": [ + "# 增加维度或者减少维度\n", + "print(x.shape)\n", + "x = x.unsqueeze(0) # 在第一维增加\n", + "print(x.shape)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1, 4, 3])\n" + ] + } + ], + "source": [ + "x = x.unsqueeze(1) # 在第二维增加\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 4, 3])\n", + "tensor([[[-1.2200, 0.9769, -2.3477],\n", + " [ 1.0125, -1.3236, -0.2626],\n", + " [-0.3501, 0.5753, 1.5657],\n", + " [ 0.4823, -0.4008, -1.3442]]])\n" + ] + } + ], + "source": [ + "x = x.squeeze(0) # 减少第一维\n", + "print(x.shape)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([4, 3])\n" + ] + } + ], + "source": [ + "x = x.squeeze() # 将 tensor 中所有的一维全部都去掉\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([3, 4, 5])\n", + "torch.Size([4, 3, 5])\n", + "torch.Size([5, 3, 4])\n" + ] + } + ], + "source": [ + "x = torch.randn(3, 4, 5)\n", + "print(x.shape)\n", + "\n", + "# 使用permute和transpose进行维度交换\n", + "x = x.permute(1, 0, 2) # permute 可以重新排列 tensor 的维度\n", + "print(x.shape)\n", + "\n", + "x = x.transpose(0, 2) # transpose 交换 tensor 中的两个维度\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([3, 4, 5])\n", + "torch.Size([12, 5])\n", + "torch.Size([3, 20])\n" + ] + } + ], + "source": [ + "# 使用 view 对 tensor 进行 reshape\n", + "x = torch.randn(3, 4, 5)\n", + "print(x.shape)\n", + "\n", + "x = x.view(-1, 5) # -1 表示任意的大小,5 表示第二维变成 5\n", + "print(x.shape)\n", + "\n", + "x = x.view(3, 20) # 重新 reshape 成 (3, 20) 的大小\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-3.1321, -0.9734, 0.5307, 0.4975],\n", + " [ 0.8537, 1.3424, 0.2630, -1.6658],\n", + " [-1.0088, -2.2100, -1.9233, -0.3059]])\n" + ] + } + ], + "source": [ + "x = torch.randn(3, 4)\n", + "y = torch.randn(3, 4)\n", + "\n", + "# 两个 tensor 求和\n", + "z = x + y\n", + "# z = torch.add(x, y)\n", + "print(z)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 `inplace`操作\n", + "另外,pytorch中大多数的操作都支持 `inplace` 操作,也就是可以直接对 tensor 进行操作而不需要另外开辟内存空间,方式非常简单,一般都是在操作的符号后面加`_`,比如" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([3, 3])\n", + "torch.Size([1, 3, 3])\n", + "torch.Size([3, 1, 3])\n" + ] + } + ], + "source": [ + "x = torch.ones(3, 3)\n", + "print(x.shape)\n", + "\n", + "# unsqueeze 进行 inplace\n", + "x.unsqueeze_(0)\n", + "print(x.shape)\n", + "\n", + "# transpose 进行 inplace\n", + "x.transpose_(1, 0)\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1., 1., 1.],\n", + " [1., 1., 1.],\n", + " [1., 1., 1.]])\n", + "tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]])\n" + ] + } + ], + "source": [ + "x = torch.ones(3, 3)\n", + "y = torch.ones(3, 3)\n", + "print(x)\n", + "\n", + "# add 进行 inplace\n", + "x.add_(y)\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 练习题\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* 查阅[PyTorch的Tensor文档](http://pytorch.org/docs/tensors.html)了解 tensor 的数据类型,创建一个 float64、大小是 3 x 2、随机初始化的 tensor,将其转化为 numpy 的 ndarray,输出其数据类型\n", + "* 查阅[PyTorch的Tensor文档](http://pytorch.org/docs/tensors.html)了解 tensor 更多的 API,创建一个 float32、4 x 4 的全为1的矩阵,将矩阵正中间 2 x 2 的矩阵,全部修改成2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 参考\n", + "* http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html\n", + "* http://cs231n.github.io/python-numpy-tutorial/" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/6_pytorch/1_NN/1-linear-regression-gradient-descend.ipynb b/6_pytorch/1_NN/1-linear-regression-gradient-descend.ipynb deleted file mode 100644 index f277f81..0000000 --- a/6_pytorch/1_NN/1-linear-regression-gradient-descend.ipynb +++ /dev/null @@ -1,962 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 线性模型和梯度下降\n", - "\n", - "本节我们简单回顾一下线性回归模型,并演示一下如何使用PyTorch来对线性回归模型进行建模和模型参数计算。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. 一元线性回归\n", - "一元线性模型非常简单,假设我们有变量 $x_i$ 和目标 $y_i$,每个 i 对应于一个数据点,希望建立一个模型\n", - "\n", - "$$\n", - "\\hat{y}_i = w x_i + b\n", - "$$\n", - "\n", - "$\\hat{y}_i$ 是我们预测的结果,希望通过 $\\hat{y}_i$ 来拟合目标 $y_i$,通俗来讲就是找到这个函数拟合 $y_i$ 使得误差最小,即最小化\n", - "\n", - "$$\n", - "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", - "$$" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "那么如何最小化这个误差呢?\n", - "\n", - "这里需要用到**梯度下降**,这是我们接触到的第一个优化算法,非常简单,但是却非常强大,在深度学习中被大量使用,所以让我们从简单的例子出发了解梯度下降法的原理" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 梯度下降法\n", - "在梯度下降法中,我们首先要明确梯度的概念,随后我们再了解如何使用梯度进行下降。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 梯度\n", - "梯度在数学上就是导数,如果是一个多元函数,那么梯度就是偏导数。比如一个函数f(x, y),那么 f 的梯度就是 \n", - "\n", - "$$\n", - "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n", - "$$\n", - "\n", - "可以称为 grad f(x, y) 或者 $\\nabla f(x, y)$。具体某一点 $(x_0,\\ y_0)$ 的梯度就是 $\\nabla f(x_0,\\ y_0)$。\n", - "\n", - "下面这个图片是 $f(x) = x^2$ 这个函数在 x=1 处的梯度\n", - "\n", - "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarbuh2j3j30ba0b80sy.jpg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "梯度有什么意义呢?从几何意义来讲,一个点的梯度值是这个函数变化最快的地方,具体来说,对于函数 f(x, y),在点 $(x_0, y_0)$ 处,沿着梯度 $\\nabla f(x_0,\\ y_0)$ 的方向,函数增加最快,也就是说沿着梯度的方向,我们能够更快地找到函数的极大值点,或者反过来沿着梯度的反方向,我们能够更快地找到函数的最小值点。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.2 梯度下降法\n", - "有了对梯度的理解,我们就能了解梯度下降发的原理了。上面我们需要最小化这个误差,也就是需要找到这个误差的最小值点,那么沿着梯度的反方向我们就能够找到这个最小值点。\n", - "\n", - "我们可以来看一个直观的解释。比如我们在一座大山上的某处位置,由于我们不知道怎么下山,于是决定走一步算一步,也就是在每走到一个位置的时候,求解当前位置的梯度,沿着梯度的负方向,也就是当前最陡峭的位置向下走一步,然后继续求解当前位置梯度,向这一步所在位置沿着最陡峭最易下山的位置走一步。这样一步步的走下去,一直走到觉得我们已经到了山脚。当然这样走下去,有可能我们不能走到山脚,而是到了某一个局部的山峰低处。\n", - "\n", - "类比我们的问题,就是沿着梯度的反方向,我们不断改变 w 和 b 的值,最终找到一组最好的 w 和 b 使得误差最小。\n", - "\n", - "在更新的时候,我们需要决定每次更新的幅度,比如在下山的例子中,我们需要每次往下走的那一步的长度,这个长度称为学习率,用 $\\eta$ 表示,这个学习率非常重要,不同的学习率都会导致不同的结果,学习率太小会导致下降非常缓慢,学习率太大又会导致跳动非常明显,可以看看下面的例子\n", - "\n", - "![](https://ws2.sinaimg.cn/large/006tNc79ly1fmgn23lnzjg30980gogso.gif)\n", - "\n", - "可以看到上面的学习率较为合适,而下面的学习率太大,就会导致不断跳动\n", - "\n", - "最后我们的更新公式就是\n", - "\n", - "$$\n", - "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n", - "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n", - "$$\n", - "\n", - "通过不断地迭代更新,最终我们能够找到一组最优的 w 和 b,这就是梯度下降法的原理。\n", - "\n", - "最后可以通过这张图形象地说明一下这个方法\n", - "\n", - "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarxsltfqj30gx091gn4.jpg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.3 PyTorch实现\n", - "\n", - "上面是原理部分,下面通过一个例子来进一步学习线性模型" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "from torch.autograd import Variable\n", - "\n", - "torch.manual_seed(2021)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPB0lEQVR4nO3df4xsZ13H8fd3uam6TQXSe2tM6e5CBKS5Biibpv5BlVRJbUybKGrJIoKVDWCq6F8k+4dGc/8gURNNiLrB366ILWJugjb1B9hIaHEuLbSUQNp699JS6aK0GjfQln7948z23m5m75y5O+fMc+a8X8lmZs6cO/t9ZrafPnPO8zwnMhNJUrkWZl2AJOn8DGpJKpxBLUmFM6glqXAGtSQV7kgTL3r06NFcWVlp4qUlaS6dOnXq65l5bNRzjQT1ysoKg8GgiZeWpLkUEdsHPeehD0kqnEEtSYUzqCWpcAa1JBXOoJakwhnUknRIW1uwsgILC9Xt1tZ0X7+R4XmS1BdbW7C+Dru71ePt7eoxwNradH6HPWpJOoSNjbMhvWd3t9o+LQa1JB3CmTOTbb8QBrUkHcLS0mTbL4RBLUmHcOIELC6+cNviYrV9WgxqSTqEtTXY3ITlZYiobjc3p3ciERz1IUmHtrY23WDezx61JBXOoJakwhnUklQ4g1qSCmdQS1LhDGpJKpxBLUmFM6glqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrUkFc6glqTCGdSSVDiDWpIKZ1BLUuEMakkqnEEtSYWrFdQR8csR8UBEfCEi3tdwTZKkc4wN6og4DrwLuBp4LfDjEfF9TRcmSarU6VG/BrgnM3cz81ngX4GfaLYsSdKeOkH9APDGiLg0IhaBG4Ar9u8UEesRMYiIwc7OzrTrlKTeGhvUmflF4APAncAdwH3At0fst5mZq5m5euzYsWnXKUm9VetkYmb+UWa+ITOvBb4BfLnZsiRJe47U2SkiLsvMJyJiier49DXNliVJ2lMrqIGPRsSlwDPAL2bmk82VJEk6V62gzsw3Nl2IJGk0ZyZKUuEMakkqnEEtSYUzqCWpcAa1JBXOoJakwhnUklQ4g1qaI1tbsLICCwvV7dbWrCvSNNSdmSipcFtbsL4Ou7vV4+3t6jHA2trs6tLh2aOW5sTGxtmQ3rO7W21XtxnU0pw4c2ay7eoOg1qaE0tLk21XdxjU0pw4cQIWF1+4bXGx2q5uM6ilObG2BpubsLwMEdXt5qYnEueBoz6kObK2ZjDPI3vUklQ4g1qSCmdQS1LhDGpJKpxBLUmFM6glqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrVUoJKuJl5SLX1lUKu3Sg2gvauJb29D5tmric+ivmnWUur73QWRmVN/0dXV1RwMBlN/XWla9gLo3Kt2Ly6WcUWUlZUqEPdbXobTp7tZS8nvdyki4lRmro58rk5QR8SvAL8AJHA/8M7M/OZB+xvUKl1JYbjfwkLVe90vAp57rpu1lPx+l+J8QT320EdEXA78ErCamceBFwE3T7dEqV1nzky2vU0lXU18WrWU/H53Qd1j1EeA74qII8Ai8NXmSpKaV1IY7lfS1cSnVUvJ73cXjA3qzHwM+C3gDPA48FRm3rl/v4hYj4hBRAx2dnamX6k0RSWF4X4lXU18WrWU/H53wdhj1BHxUuCjwM8ATwK3Abdn5l8e9G88Rq0u2NqCjY3q6/fSUhUanthqju/3+R3qZGJE/BRwfWbeMnz8duCazHzvQf/GoJakyRzqZCLVIY9rImIxIgK4DvjiNAuUJB2szjHqe4Dbgc9SDc1bADYbrkuSNFRr1Edm/lpmfn9mHs/Mn83MbzVdmKRuceZhc47MugBJ3bd/5uHeVHPwhOE0uNaHWmWvaz5tbLxwejhUjzc2ZlPPvLFHrdbY65pfzjxslj1qtabvva55/jbhzMNmGdRqTZ97XSUtXdoEZx42y6BWa/rc65r3bxMlTXufRwa1WtPnXlcfvk2srVVLlj73XHVrSE+PQa3W9LnX1edvEzo8g1qt6muvq8/fJnR4BrXUgj5/m9DhOY5aasnamsGsC2OPWpIKZ1BLOtA8T9LpEg99SBrJKf/lsEctaaR5n6TTJQa1pJH6MEmnKwxqSSM5SaccBrWkkZykUw6DWtJITtIph0EtdVBbw+b6OuW/NA7PkzrGYXP9Y49a6hiHzfWPQS11jMPm+segljrGYXP9Y1BLHeOwuf4xqKWOcdhc/zjqQ+og17buF3vUklQ4g1qSCmdQS1LhDGpJKpxBLUmFGxvUEfHqiLjvnJ//iYj3tVCbJIkaQZ2ZX8rM12Xm64A3ALvAx5ouTNLseFHbskw6jvo64OHM3G6iGEmz5+p85Zn0GPXNwIebKERSGVydrzy1gzoiLgJuBG474Pn1iBhExGBnZ2da9UlqmavzlWeSHvWPAZ/NzK+NejIzNzNzNTNXjx07Np3qJLXO1fnKM0lQv5UGD3t48kIqg6vzladWUEfExcCPAn/bRBF7Jy+2tyHz7MkLw1pqn6vzlScyc+ovurq6moPBoPb+KytVOO+3vFxdUFOS5l1EnMrM1VHPFTEz0ZMXknSwIoLakxeSdLAigtqTF/V50lXqnyKC2pMX9XjSVeqnIk4mqh5Pukrzq/iTiarHk65SPxnUHeJJV6mfDOoO8aSr1E8GdYd40lXqp0nXo9aMra0ZzFLf2KOWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBnUPuVSq1C1OeOmZvaVSd3erx3tLpYITaaRS2aPumY2NsyG9Z3e32i6pTAZ1z0xjqVQPnUjtMqh75rBLpXqVGal9BnXPHHapVA+dSO0zqHvmsEulepUZqX2O+uihwyyVurQ0+rqNXmVGao49ak3Eq8xI7TOoO2wWoy+8yozUPg99dNQsJ654lRmpXfaoO8rRF1J/GNQd5egLqT8M6o467MQVSd1hUHeUoy+k/jCoO8rRF1J/1Br1EREvAT4EHAcS+PnM/HSDdakGR19I/VB3eN7vAndk5lsi4iJgcdw/kCRNx9igjogXA9cC7wDIzKeBp5stS5K0p84x6pcDO8CfRMS9EfGhiLh4/04RsR4Rg4gY7OzsTL1QSeqrOkF9BLgK+P3MfD3wf8D79++UmZuZuZqZq8eOHZtymZLUX3WC+lHg0cy8Z/j4dqrgliS1YGxQZ+Z/Al+JiFcPN10HPNhoVZKk59Ud9XErsDUc8fEI8M7mSpIknatWUGfmfcBqs6VIkkbp1cxEr54tqYt6sx71LNdvlqTD6E2P2vWbJXVVb4La9ZsldVVvgtr1myV1VW+C2vWbJXVVb4La9ZsldVVvRn2A6zdL6qbe9KglqasM6gI4EUfS+fTq0EeJnIgjaRx71DPmRBxJ4xjUM+ZEHEnjGNQz5kQcSeMY1DPmRBxJ4xjUM+ZEHEnjOOqjAE7EkXQ+9qglqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrUkFc6glqTCGdSSVDiDWpIKZ1BLUuEMakkqXK31qCPiNPC/wLeBZzNztcmiJElnTXLhgDdl5tcbq0SSNNLcHPrY2oKVFVhYqG63tmZdkSRNR92gTuDOiDgVEeujdoiI9YgYRMRgZ2dnehXWsLUF6+uwvQ2Z1e36umEtaT5EZo7fKeLyzHwsIi4D/hG4NTPvOmj/1dXVHAwGUyzz/FZWqnDeb3kZTp9urQxJumARceqg83+1etSZ+djw9gngY8DV0yvv8M6cmWy7JHXJ2KCOiIsj4pK9+8CbgQeaLmwSS0uTbZekLqnTo/4e4N8i4nPAZ4CPZ+YdzZY1mRMnYHHxhdsWF6vtktR1Y4fnZeYjwGtbqOWCra1Vtxsb1eGOpaUqpPe2S1KXTTKOumhrawazpPk0N+OoJWleGdSSVDiDWpIKZ1BLUuEMakkqXK0p5BO/aMQOMGJS9/OOAn1dic+291Nf297XdsPkbV/OzGOjnmgkqMeJiEFf17S27ba9T/rabphu2z30IUmFM6glqXCzCurNGf3eEtj2fupr2/vabphi22dyjFqSVJ+HPiSpcAa1JBWu0aCOiOsj4ksR8VBEvH/E898RER8ZPn9PRKw0WU+barT9VyPiwYj4fET8c0Qsz6LOJoxr+zn7/WREZETMxfCtOu2OiJ8efu5fiIi/arvGptT4e1+KiE9ExL3Dv/kbZlHntEXEH0fEExEx8mIqUfm94fvy+Yi46oJ+UWY28gO8CHgYeAVwEfA54Mp9+7wX+IPh/ZuBjzRVT5s/Ndv+JmBxeP89fWr7cL9LgLuAu4HVWdfd0mf+SuBe4KXDx5fNuu4W274JvGd4/0rg9KzrnlLbrwWuAh444PkbgH8AArgGuOdCfk+TPeqrgYcy85HMfBr4a+CmffvcBPzZ8P7twHUREQ3W1Jaxbc/MT2Tm7vDh3cDLWq6xKXU+d4DfBD4AfLPN4hpUp93vAj6Ymd+A569BOg/qtD2B7x7efzHw1Rbra0xWF/n+7/PschPw51m5G3hJRHzvpL+nyaC+HPjKOY8fHW4buU9mPgs8BVzaYE1tqdP2c91C9X/deTC27cOvf1dk5sfbLKxhdT7zVwGviohPRcTdEXF9a9U1q07bfx14W0Q8Cvw9cGs7pc3cpFkw0txc4aWrIuJtwCrwQ7OupQ0RsQD8DvCOGZcyC0eoDn/8MNU3qLsi4gcy88lZFtWStwJ/mpm/HRE/CPxFRBzPzOdmXVgXNNmjfgy44pzHLxtuG7lPRByh+kr0Xw3W1JY6bScifgTYAG7MzG+1VFvTxrX9EuA48MmIOE113O7kHJxQrPOZPwqczMxnMvM/gC9TBXfX1Wn7LcDfAGTmp4HvpFq0aN7VyoJxmgzqfwdeGREvj4iLqE4Wnty3z0ng54b33wL8Sw6PwHfc2LZHxOuBP6QK6Xk5Vglj2p6ZT2Xm0cxcycwVquPzN2bmYDblTk2dv/e/o+pNExFHqQ6FPNJijU2p0/YzwHUAEfEaqqDeabXK2TgJvH04+uMa4KnMfHziV2n4jOgNVL2Gh4GN4bbfoPoPE6oP6zbgIeAzwCtmfRa3xbb/E/A14L7hz8lZ19xW2/ft+0nmYNRHzc88qA77PAjcD9w865pbbPuVwKeoRoTcB7x51jVPqd0fBh4HnqH6xnQL8G7g3ed85h8cvi/3X+jfulPIJalwzkyUpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalw/w9HECtz8n/B+wAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# 生层测试数据\n", - "x_train = np.random.rand(20, 1)\n", - "y_train = x_train * 3 + 4 + 3*np.random.rand(20,1)\n", - "\n", - "# 画出图像\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "\n", - "plt.plot(x_train, y_train, 'bo')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# 转换成 Tensor\n", - "x_train = torch.from_numpy(x_train)\n", - "y_train = torch.from_numpy(y_train)\n", - "\n", - "# 定义参数 w 和 b\n", - "w = Variable(torch.randn(1), requires_grad=True) # 随机初始化\n", - "b = Variable(torch.zeros(1), requires_grad=True) # 使用 0 进行初始化" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# 构建线性回归模型\n", - "x_train = Variable(x_train)\n", - "y_train = Variable(y_train)\n", - "\n", - "def linear_model(x):\n", - " return x * w + b\n", - "\n", - "def logistc_regression(x):\n", - " return torch.sigmoid(x*w+b) " - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "y_ = linear_model(x_train)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "经过上面的步骤我们就定义好了模型,在进行参数更新之前,我们可以先看看模型的输出结果长什么样" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXy0lEQVR4nO3df3Ac5X3H8c9XRmAELqGyygCOJDOTeBA2GFsQM5kYF4xxgQkwpC1UBJzBtYHCkLahA6PpQAs07UwbF5gkoCEOPyQSgmmop6UtBcw4P2xADoIkNtiJkY2MEyuGOOAf9Q99+8eehHzWSafbvb3dvfdr5ubuVqvd5zmZD889++zzmLsLAJA+NZUuAACgNAQ4AKQUAQ4AKUWAA0BKEeAAkFJHxXmyyZMne3Nzc5ynBIDUW7du3W/cvSF/e6wB3tzcrO7u7jhPCQCpZ2ZbRto+ZheKmS03sx1m9rNh237fzP7XzDblnk+MsrAAgLEV0wf+qKSFedvukPSiu39K0ou59wCAGI0Z4O6+WtL7eZsvl/RY7vVjkq6ItlgAgLGU2gd+krtvz73+laSTCu1oZkskLZGkxsbGI35+4MAB9fX1ad++fSUWBfkmTpyoKVOmqLa2ttJFAVBGoS9iurubWcEJVdy9Q1KHJLW2th6xX19fnyZNmqTm5maZWdjiVD13186dO9XX16epU6dWujgAyqjUceC/NrOTJSn3vKPUAuzbt0/19fWEd0TMTPX19XyjARKgq0tqbpZqaoLnrq5oj19qgK+UdH3u9fWS/j1MIQjvaPF5ApXX1SUtWSJt2SK5B89LlkQb4sUMI/yOpDWSpplZn5ndIOkfJV1kZpskzc+9BwDktLdLe/Ycvm3PnmB7VMbsA3f3awr86MLoipFugzcoTZ48udJFAZAQW7eOb3spUjcXSrn7lNxdAwMD0R4UQNUZYdDdqNtLkaoAL1efUm9vr6ZNm6brrrtO06dP1z333KNzzjlHZ555pu66666h/a644grNnj1bZ5xxhjo6OkLWBkCW3XefVFd3+La6umB7VFIV4OXsU9q0aZNuvvlmLVu2TNu2bdOrr76qnp4erVu3TqtXr5YkLV++XOvWrVN3d7ceeOAB7dy5M/yJAWRSW5vU0SE1NUlmwXNHR7A9KrFOZhVWOfuUmpqaNGfOHH3lK1/R888/r7PPPluS9NFHH2nTpk2aO3euHnjgAX3/+9+XJL377rvatGmT6uvrw58cQCa1tUUb2PlSFeCNjUG3yUjbwzruuOMkBX3gd955p5YuXXrYz19++WW98MILWrNmjerq6jRv3jzGWgOoqFR1ocTRp3TxxRdr+fLl+uijjyRJ27Zt044dO7Rr1y6deOKJqqur01tvvaW1a9dGd1IAKEGqWuCDX0Xa24Nuk8bGILyj/IqyYMECbdiwQeedd54k6fjjj1dnZ6cWLlyohx56SKeffrqmTZumOXPmRHdSACiBuRecxiRyra2tnr+gw4YNG3T66afHVoZqwecKZIeZrXP31vztqepCAQB8jAAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsDH4dFHH9V777039H7x4sVav3596OP29vbqySefHPfvLVq0SCtWrAh9fgDplL4AL/d8sqPID/BHHnlELS0toY9baoADqG7pCvAyzSfb2dmpc889VzNnztTSpUt16NAhLVq0SNOnT9eMGTO0bNkyrVixQt3d3Wpra9PMmTO1d+9ezZs3T4M3Jh1//PG6/fbbdcYZZ2j+/Pl69dVXNW/ePJ122mlauXKlpCCoP/e5z2nWrFmaNWuWfvzjH0uS7rjjDv3gBz/QzJkztWzZMh06dEi333770JS2Dz/8sKRgnpZbbrlF06ZN0/z587VjR8lLkQLIAneP7TF79mzPt379+iO2FdTU5B5E9+GPpqbijzHC+S+77DLfv3+/u7vfdNNNfvfdd/v8+fOH9vnggw/c3f3888/31157bWj78PeS/LnnnnN39yuuuMIvuugi379/v/f09PhZZ53l7u67d+/2vXv3urv7xo0bffDzWLVqlV966aVDx3344Yf9nnvucXf3ffv2+ezZs33z5s3+zDPP+Pz58/3gwYO+bds2P+GEE/zpp58uWC8A2SCp20fI1FTNhVKO+WRffPFFrVu3Tuecc44kae/evVq4cKE2b96sW2+9VZdeeqkWLFgw5nGOPvpoLVy4UJI0Y8YMHXPMMaqtrdWMGTPU29srSTpw4IBuueUW9fT0aMKECdq4ceOIx3r++ef15ptvDvVv79q1S5s2bdLq1at1zTXXaMKECTrllFN0wQUXlFxvAOmXri6UMqxR5O66/vrr1dPTo56eHr399tu6//779cYbb2jevHl66KGHtHjx4jGPU1tbO7QafE1NjY455pih1wcPHpQkLVu2TCeddJLeeOMNdXd3a//+/QXL9OCDDw6V6Z133inqfyLAaCp4+Qhlkq4AL8N8shdeeKFWrFgx1J/8/vvva8uWLRoYGNBVV12le++9Vz/5yU8kSZMmTdKHH35Y8rl27dqlk08+WTU1NXriiSd06NChEY978cUX65vf/KYOHDggSdq4caN2796tuXPn6qmnntKhQ4e0fft2rVq1quSyoLqUazlCVFa6ulDKMJ9sS0uL7r33Xi1YsEADAwOqra3V1772NV155ZVDixt/9atflRQM27vxxht17LHHas2aNeM+180336yrrrpKjz/+uBYuXDi0iMSZZ56pCRMm6KyzztKiRYt02223qbe3V7NmzZK7q6GhQc8++6yuvPJKvfTSS2ppaVFjY+PQlLfAWEZbjrCcK8agvJhONqP4XDFcTU3Q8s5nJuXaKUgwppMFqlgZLh8hAQhwoArEsRwh4peIAI+zG6ca8HkiX1ub1NEhNTUF3SZNTcF7+r/TreIXMSdOnKidO3eqvr5+aBgeSufu2rlzpyZOnFjpoiBh2toI7KypeIBPmTJFfX196u/vr3RRMmPixImaMmVKpYsBoMwqHuC1tbWaOnVqpYsBAKmTiD5wAMD4EeBAyiTllviklKOaEeBAniQHU1JuiY+yHEn+vJMu1J2YZvaXkhZLckk/lfQld99XaP+R7sQEkmQwmIbfdl5Xl5whd83NQVjma2qScpNepqocSf+8k6LQnZglB7iZnSrph5Ja3H2vmX1P0nPu/mih3yHAkXRJCchCknJLfFTlSPrnnRTlupX+KEnHmtlRkuokvTfG/kCilWHK+Ugl5Zb4qMqR9M876UoOcHffJumfJW2VtF3SLnd/Pn8/M1tiZt1m1s1YbyRdUgKykKTcEh9VOZL+eSddyQFuZidKulzSVEmnSDrOzK7N38/dO9y91d1bGxoaSi8pEIOkBGQhSbklfjzlGO0iZdI/78QbaZ21Yh6S/ljSt4a9v07SN0b7nZHWxATc3Ts7g6VNzYLnzs7qKkuS6h+lzk73urrDl7Ctqzu8flmte5RUYE3MMBcxPyNpuaRzJO2V9GjuJA8W+h0uYmIk1T4SIcv15yJlNCIfhZI76N9J+lNJByW9Lmmxu/9fof0JcIyk2v8jz3L9kzJqJu0KBXiouVDc/S5Jd4U5BlDtIxGyXP/GxpH/58RFymhwJyYqrtpHImS5/lykLC8CHBVX7f+Rp7H+xd7+npRRM5k10pXNcj0YhYJCqn0kQprqX8zIEkRLUY9CKQUXMYH0y/JF16RiVXoAkcjyRde0IcCBDIljatYsX3RNGwIcyIi45gpP40XXrCLAgYxobz/8bk4peN/eHu15GFmSHFzEBDKCux6zi4uYQMbRN119CHAgI+ibrj4EOJARcfVNswhxcoSazApAsrS1lfdiYv7Ut4MjXQbPjXjRAgdQtLhGuqA4iQ9wvq4BycFdmMmS6ACP68YEAMVhpEuyJDrA+bpWPL6pIA6MdEmWRAc4X9eKwzcVxIW7MJMl0XdiMm1lcficgGxL5Z2YfF0rDt9UgOqU6ADn61pxuLAEVKdEB7gUhHVvbzAZT28v4T0SvqkA1SnxAY6x8U0FqE4EeEZE8U2FoYhAujAXCiQxxwWQRrTAISmam6ZowQPxogUOSeGHItKCB+JHCxySwg9FZNoDIH4EeAaV0pURdigiNxMB8SPAM6bUeVHCDkXkZiIgfomeCwXjV6l5UfL7wKWgBc94dCC8ssyFYmafMLMVZvaWmW0ws/PCHA/hVaorg5uJgPiFHYVyv6T/dvcvmNnRkurG+gWUV2PjyC3wOLoyyr0eI4DDldwCN7MTJM2V9C1Jcvf97v7biMqFEjEvClA9wnShTJXUL+nbZva6mT1iZsfl72RmS8ys28y6+/v7Q5wOxaArA6geJV/ENLNWSWslfdbdXzGz+yX9zt3/ttDvcBETAMavHBcx+yT1ufsrufcrJM0KcbyK4jZwAGlTcoC7+68kvWtm03KbLpS0PpJSxYw1JQGkUdgbeW6V1GVmb0qaKekfQpeoArgNHEAahRpG6O49ko7ol0kbbgMHkEbcSq9k3wZO3zyAQghwJXfsNH3zAEZDgCu5Y6fpmwcwGiazSrCamqDlnc8sWPsSQHUoy2RWKK8k980DqDwCPMGS2jcPIBkI8ARLat88gGRgUeOEY4pWAIXQAgeAlCLAASClCHAASCkCHABSigAHgJTKfIAzGRSArMr0MMLByaAG5xMZnAxKYmgegPTLdAucyaAAZFmmA5yFGgBkWaYDnMmgAGRZpgOcyaAAZFmmA5zJoABkWaZHoUhMBgUguzLdAgeALCPAASClCHAASCkCHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsABIKUIcABIqdABbmYTzOx1M/uPKAoEAChOFC3w2yRtiOA4AIBxCBXgZjZF0qWSHommOACAYoVtgf+rpL+RNFBoBzNbYmbdZtbd398f8nQAgEElB7iZXSZph7uvG20/d+9w91Z3b21oaCj1dACAPGFa4J+V9Hkz65X0XUkXmFlnJKUCAIyp5AB39zvdfYq7N0u6WtJL7n5tZCUDAIyKceAAkFKRrInp7i9LejmKYwEAikMLHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsABIKUIcABIKQIcAFKKAAeAlCLAASClCHAASCkCHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsABIKUIcABIKQIcAFKKAAeAlCLAASClCHAASCkCHABSigAHgJQiwAEgpQhwACiXri6puVmqqQmeu7oiPfxRkR4NABDo6pKWLJH27Aneb9kSvJektrZITkELHADKob394/AetGdPsD0iBDgAlMPWrePbXgICHADKobFxfNtLUHKAm9knzWyVma03s5+b2W2RlQoA0u6++6S6usO31dUF2yMSpgV+UNJfu3uLpDmS/sLMWqIpFgCkXFub1NEhNTVJZsFzR0dkFzClEKNQ3H27pO251x+a2QZJp0paH1HZACDd2toiDex8kfSBm1mzpLMlvTLCz5aYWbeZdff390dxOgBZUuax0lkWOsDN7HhJz0j6srv/Lv/n7t7h7q3u3trQ0BD2dADSZrSAHhwrvWWL5P7xWGlCvCihAtzMahWEd5e7/1s0RQKQCsW0nMcK6BjGSmeZuXtpv2hmkh6T9L67f7mY32ltbfXu7u6SzgcgIbq6pKVLpd27D99eV3fkRbrm5iC08zU1Sb29QfiPlEFm0sBAlKVONTNb5+6t+dvDtMA/K+mLki4ws57c45IQxwOQdF1d0pe+dGR4SyO3nMe6mSWGsdJZVnKAu/sP3d3c/Ux3n5l7PBdl4QAkTHu7dOBA4Z/nB/ZYAR3DWOks405MoFpEMdpjrNvA8wN7rICOYax0lhHgQJYUCumoRnuM1rVhdmTLuZiAbmsL+sMHBoJnwrt47h7bY/bs2Q4gpM5O96Ymd7PgubPz4+11de5BRAePurqP9x++ffDR1DT+c9fWjnysm26Ktp4YIqnbR8hUWuBAEhTbvTFaS3q0IXlRzYzX1iZ9+9tSff3H2+rrpc5O6RvfGN+xEFrJwwhLwTBCYAT5E/9LIw/Jk0Yflrd1a+EheY2Now/nQ6KVYxghgCiM52aW0VrSo434YLRHJhHgQKWNp3uj1JBmtEcmEeBApY3nZpYwIc1oj8whwIF8cc+ON57uDUIaw7AqPTBcDCuJH2HwuIOjRQa7Qwqdr8xzTCM9GIUCDDfW5EtABTAKBShGDCuJA1EhwIHhmB0PKUKAIxmSsqwW46WRIgQ4Ki9Jy2oxXhopQoAjOqW2opO2rBZD8ZASDCNENMIMv+PCIVASWuD4WJh+6DCtaC4cAiUhwBEI2w8dphXNhUOgJAR4FpXSkg7bDx2mFc2FQ6AkBHhWDIa2mfTFL46/JR22HzpsK5oLh8C4EeBJV0xrenj3h3TkpP7FtKTD9kPTigZix1woSVbsSi2F5u8Yzixo3YY9F4DYMRdKGhXbL11MN8dYLWla0EDqMA48yYrtly603uGgYvuimaYUSBVa4ElWbL/0SBcQzYJnWtJAZhHgSVbsyI6Ruj+eeCK4mMmIDiCzCPBBSZkNb7jx9EszDA+oOtkP8PEOw6v0bHj5CGYABWQ7wIsN5qTNhgcARch2gIcdhsdseAASLPkBHqZvejzD8EbCbHgAEixUgJvZQjN728x+YWZ3RFWoIWH7psMMw2M2PAAJV3KAm9kESV+X9EeSWiRdY2YtURVMUvi+6TDD8Bg7DSDhwtyJea6kX7j7Zkkys+9KulzS+igKJil83/RgALe3B7/T2BiEd6FheAQ2gBQJE+CnSnp32Ps+SZ/J38nMlkhaIkmN4+1TLnSL+HiOQzADyKiyX8R09w53b3X31oaGhvH9Mn3TAFBQmADfJumTw95PyW2LDn3TAFBQmC6U1yR9ysymKgjuqyX9WSSlGo4uEAAYUckB7u4HzewWSf8jaYKk5e7+88hKBgAYVaj5wN39OUnPRVQWAMA4JP9OTADAiAhwAEgpAhwAUirWVenNrF/SGMuna7Kk38RQnKSp1npL1L0a616t9ZZKq3uTux9xI02sAV4MM+t299ZKlyNu1VpvibpXY92rtd5StHWnCwUAUooAB4CUSmKAd1S6ABVSrfWWqHs1qtZ6SxHWPXF94ACA4iSxBQ4AKAIBDgApVZEAH2stTTM7xsyeyv38FTNrrkAxy6KIuv+Vma03szfN7EUza6pEOcuh2DVUzewqM3Mzy8Qws2LqbWZ/kvu7/9zMnoy7jOVSxL/3RjNbZWav5/7NX1KJckbNzJab2Q4z+1mBn5uZPZD7XN40s1klncjdY30omLnwl5JOk3S0pDckteTtc7Okh3Kvr5b0VNzlrGDd/1BSXe71TdVU99x+kyStlrRWUmulyx3T3/xTkl6XdGLu/R9Uutwx1r1D0k251y2Seitd7ojqPlfSLEk/K/DzSyT9lySTNEfSK6WcpxIt8KG1NN19v6TBtTSHu1zSY7nXKyRdaGYWYxnLZcy6u/sqdx9cyXmtgoUysqCYv7sk3SPpnyTti7NwZVRMvf9c0tfd/QNJcvcdMZexXIqpu0v6vdzrEyS9F2P5ysbdV0t6f5RdLpf0uAfWSvqEmZ083vNUIsBHWkvz1EL7uPtBSbsk1cdSuvIqpu7D3aDg/9JZMGbdc18jP+nu/xlnwcqsmL/5pyV92sx+ZGZrzWxhbKUrr2Lqfreka82sT8HU1LfGU7SKG28WjCjUfOAoHzO7VlKrpPMrXZY4mFmNpK9JWlTholTCUQq6UeYp+Ma12sxmuPtvK1momFwj6VF3/xczO0/SE2Y23d0HKl2wNKhEC7yYtTSH9jGzoxR8tdoZS+nKq6h1RM1svqR2SZ939/+LqWzlNlbdJ0maLullM+tV0C+4MgMXMov5m/dJWunuB9z9HUkbFQR62hVT9xskfU+S3H2NpIkKJnvKukjWFK5EgA+tpWlmRyu4SLkyb5+Vkq7Pvf6CpJc81/OfcmPW3czOlvSwgvDOSl+oNEbd3X2Xu09292Z3b1bQ//95d++uTHEjU8y/92cVtL5lZpMVdKlsjrGM5VJM3bdKulCSzOx0BQHeH2spK2OlpOtyo1HmSNrl7tvHfZQKXaG9REEr45eS2nPb/l7Bf7BS8Ed8WtIvJL0q6bRKX1WOse4vSPq1pJ7cY2WlyxxX3fP2fVkZGIVS5N/cFHQfrZf0U0lXV7rMMda9RdKPFIxQ6ZG0oNJljqje35G0XdIBBd+wbpB0o6Qbh/3Nv577XH5a6r91bqUHgJTiTkwASCkCHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CU+n81PmNJdk5fugAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", - "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**思考:红色的点表示预测值,似乎排列成一条直线,请思考一下这些点是否在一条直线上?**" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "这个时候需要计算我们的误差函数,也就是\n", - "\n", - "$$\n", - "E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", - "$$" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# 计算误差\n", - "def get_loss(y_, y):\n", - " return torch.sum((y_ - y) ** 2)\n", - "\n", - "loss = get_loss(y_, y_train)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(748.8935, dtype=torch.float64, grad_fn=)\n" - ] - } - ], - "source": [ - "# 打印一下看看 loss 的大小\n", - "print(loss)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "定义好了误差函数,接下来我们需要计算 w 和 b 的梯度了,这时得益于 PyTorch 的自动求导,我们不需要手动去算梯度,有兴趣的同学可以手动计算一下,w 和 b 的梯度分别是\n", - "\n", - "$$\n", - "\\frac{\\partial}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n", - "\\frac{\\partial}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n", - "$$" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# 自动求导\n", - "loss.backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([-125.1102])\n", - "tensor([-243.2102])\n" - ] - } - ], - "source": [ - "# 查看 w 和 b 的梯度\n", - "print(w.grad)\n", - "print(b.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "# 更新一次参数\n", - "w.data = w.data - 1e-2 * w.grad.data\n", - "b.data = b.data - 1e-2 * b.grad.data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "更新完成参数之后,我们再一次看看模型输出的结果" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD6CAYAAAC4RRw1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZQ0lEQVR4nO3dfZBV9X3H8fd3cRXXUKPLxpqS3cWZ1IogCIvFaSVUEWh0fChpJmYTJU2CmujYdmpLhuloq9s0mVTaZNrEHUN8YLUqaVMmtQl5wJImol4MGoMGErroqg0rGnwAwsN++8e5i8v17t6755x7z8P9vGbu3HvPPXvO73cWvvu739/DMXdHRESypynpAoiISDgK4CIiGaUALiKSUQrgIiIZpQAuIpJRCuAiIhlVMYCb2Woz22VmT4/YdrKZfcfMthefT6ptMUVEpJRVGgduZvOBN4C73X16cdvngVfc/e/NbAVwkrv/VaWTTZ482Ts7O6OXWkSkgWzevPlld28r3X5MpR90941m1lmy+VJgQfH1XcDDQMUA3tnZSaFQqLSbiIiMYGY7y20PmwM/xd1fKr7+P+CUkMcREZGQIndiepCDGTUPY2bLzaxgZoXBwcGopxMRkaKwAfyXZnYqQPF512g7unuvu3e5e1db29tSOCIiElLFHPgo1gFXAX9ffP6PsAU4ePAgAwMD7N+/P+whpMTEiROZMmUKzc3NSRdFRGqoYgA3s/sIOiwnm9kAcBNB4H7AzD4O7AQ+GLYAAwMDTJo0ic7OTsws7GGkyN3ZvXs3AwMDTJ06NeniiEgNVUyhuPsV7n6quze7+xR3/6q773b3C9z9ve6+0N1fCVuA/fv309raquAdEzOjtbVV32hEUqCvDzo7oakpeO7ri/f4YVMosVLwjpeup0jy+vpg+XLYuzd4v3Nn8B6guzuec2gqfQw6Ozt5+eWXky6GiKTIypVvBe9he/cG2+OiAF7C3RkaGkq6GCKScc89N77tYWQugNcip9Tf38/pp5/OlVdeyfTp07nllluYO3cuZ511FjfddNOR/S677DLmzJnDmWeeSW9vb/QTi0hutbePb3sYqciBV6uWOaXt27dz11138dprr7F27Voee+wx3J1LLrmEjRs3Mn/+fFavXs3JJ5/Mvn37mDt3LkuXLqW1tTXaiUUkl3p6jo5XAC0twfa4ZKoFXsucUkdHB/PmzWP9+vWsX7+es88+m9mzZ/Pss8+yfft2AL74xS8yc+ZM5s2bx/PPP39ku4hIqe5u6O2Fjg4wC557e+PrwISMtcBrmVM64YQTgCAH/pnPfIarr776qM8ffvhhvvvd7/LII4/Q0tLCggULNFRPRMbU3R1vwC6VqRZ4PXJKixcvZvXq1bzxxhsAvPDCC+zatYs9e/Zw0kkn0dLSwrPPPsumTZviO6mISAiZCuA9PUEOaaS4c0qLFi3iwx/+MOeeey4zZszgAx/4AK+//jpLlizh0KFDnHHGGaxYsYJ58+bFd1IRkRAq3tAhTl1dXV66HvgzzzzDGWecUfUx+vqCnPdzzwUt756e2n5FyarxXlcRSS8z2+zuXaXbM5UDh9rnlEREsiJTKRQREXmLAriISEYpgIuIZJQCuIhIRimAi4hklAL4ONx55528+OKLR95/4hOfYOvWrZGP29/fz7333jvun1u2bBlr166NfH4RySYF8HEoDeB33HEH06ZNi3zcsAFcRBpb9gJ4DdaTXbNmDeeccw6zZs3i6quv5vDhwyxbtozp06czY8YMVq1axdq1aykUCnR3dzNr1iz27dvHggULGJ6Y9I53vIMbb7yRM888k4ULF/LYY4+xYMECTjvtNNatWwcEgfq8885j9uzZzJ49mx/96EcArFixgh/84AfMmjWLVatWcfjwYW688cYjS9refvvtQLBOy3XXXcfpp5/OwoUL2bVrV+S6i0iGuXvoB3AD8DTwU+BPK+0/Z84cL7V169a3bRvVmjXuLS3u8NajpSXYHtLWrVv94osv9gMHDri7+7XXXus333yzL1y48Mg+r776qru7v+997/PHH3/8yPaR7wF/6KGH3N39sssu8wsvvNAPHDjgW7Zs8ZkzZ7q7+5tvvun79u1zd/dt27b58PXYsGGDX3TRRUeOe/vtt/stt9zi7u779+/3OXPm+I4dO/zrX/+6L1y40A8dOuQvvPCCn3jiif7ggw+OWi8RyQeg4GViauiZmGY2HfgkcA5wAPiWmX3T3X8e+a/KaMZaTzbk9Mzvfe97bN68mblz5wKwb98+lixZwo4dO7j++uu56KKLWLRoUcXjHHvssSxZsgSAGTNmcNxxx9Hc3MyMGTPo7+8H4ODBg1x33XVs2bKFCRMmsG3btrLHWr9+PU899dSR/PaePXvYvn07Gzdu5IorrmDChAm8+93v5vzzzw9VZxHJhyhT6c8AHnX3vQBm9t/AHwGfj6NgZdVgPVl356qrruKzn/3sUdt7enr49re/zVe+8hUeeOABVq9ePeZxmpubj9xMuKmpieOOO+7I60OHDgGwatUqTjnlFJ588kmGhoaYOHHiqGX60pe+xOLFi4/a/tBDD4Wqo4jkU5Qc+NPAeWbWamYtwPuB95TuZGbLzaxgZoXBwcEIp6Mm68lecMEFrF279kg++ZVXXmHnzp0MDQ2xdOlSbr31Vp544gkAJk2axOuvvx76XHv27OHUU0+lqamJe+65h8OHD5c97uLFi/nyl7/MwYMHAdi2bRtvvvkm8+fP5/777+fw4cO89NJLbNiwIXRZRCT7QrfA3f0ZM/scsB54E9gCHC6zXy/QC8FqhGHPB9TkHkXTpk3j1ltvZdGiRQwNDdHc3Mxtt93G5ZdffuTmxsOt82XLlnHNNddw/PHH88gjj4z7XJ/61KdYunQpd999N0uWLDlyE4mzzjqLCRMmMHPmTJYtW8YNN9xAf38/s2fPxt1pa2vjG9/4Bpdffjnf//73mTZtGu3t7Zx77rmh6y0i2RfbcrJm9nfAgLv/y2j7xLGcrNaTrY6WkxXJj5osJ2tm73L3XWbWTpD/rv1dDrSerIgIEH098K+bWStwEPi0u/8qepFERKQakQK4u58XV0FERGR8UjETM648vAR0PUUaQ+IBfOLEiezevVtBJybuzu7du0cdYy4i+ZH4PTGnTJnCwMAAkceIyxETJ05kypQpSRdDUkYDuPIn8QDe3NzM1KlTky6GSK719R09hWLnzuA9KIhnWeIpFBGpvbGWEZLsUgAXaQA1WEZIUkABXKQB1GAZIUkBBXCRBtDTEywbNFLEZYQkBRTARRpAdzf09kJHB5gFz7296sDMusRHoYhIfWgZofxRC1xEJKMUwEVEMkoBXEQkoxTARUQySgFcRCSjFMBFRDJKAVwkY/r6oLMTmpqC576+pEskSVEAFymR5gA5vKrgzp3g/taqgkmUMa7rlObrnXruHvoB/BnwU+Bp4D5g4lj7z5kzx0XSbM0a95YW9yA8Bo+WlmB7GnR0HF224UdHR33LEdd1Svv1Tgug4GViqnnIO+GY2W8B/wNMc/d9ZvYA8JC73znaz3R1dXmhUAh1PpF66OwMWrWlOjqgv7/epXm7pqYgzJUyg6Gh+pUjruuU9uudFma22d27SrdHTaEcAxxvZscALcCLEY8nkqi0L7uallUF47pOab/eaRc6gLv7C8AXgOeAl4A97r4+roKJJCEtAXI0aVlVMK7rlPbrnXahA7iZnQRcCkwF3g2cYGYfKbPfcjMrmFlB972UtEtLgBxNWlYVjOs6pf16p165xHg1D+CPga+OeH8l8C9j/Yw6MSUL1qwJOgXNgmd1qJUX13XS9a6MGnRi/i6wGpgL7APuLJ7kS6P9jDoxRRpPX19w783nngtSIz09WtZ2vGLvxHT3R4G1wBPAT4rH6g1dQmloGgucT2kat55HkUahuPtN7v477j7d3T/q7r+Oq2DSOPSfPL9/wFauhL17j962d2+wXaILnUIJQykUKafRxwIP/wEbGehaWvJxy7O0jFvPulqNAxeJrNHHAue5laphgrWlAC6Ja/T/5Hn+A6ZhgrWlAC6Ja/T/5Hn+A5aWcet5pQAuiWv0/+R5/wPW3R30ZQwNBc+N8nutBwVwSYVG/k+exT9geR01kzXHJF0AEQmCdZoD9kilo2aGh31CduqQF2qBi8i45HnUTNYogIvIuOR51EzWKICLyLjkedRM1iiAi+RIPToX8z5qJksUwEVyol5rymRx1ExeaS0UkZxo9DVl8kxroYjknDoXG48CuEhOqHOx8SiAi+SEOhcbjwK4SE6oc7HxaCq9SI5kaUq+RBe6BW5mp5vZlhGP18zsT2Msm4iIjCHKTY1/5u6z3H0WMAfYC/x7XAUTkXTSSoTpEVcK5QLgF+5eZhSqiOSFViJMl7g6MT8E3BfTsUQkpbQSYbpEDuBmdixwCfDgKJ8vN7OCmRUGBwejnk5EEqTJQukSRwv8D4En3P2X5T50915373L3rra2tnEfXPk2kfTQZKF0iSOAX0GN0if1WpxHRKqjyULpEimAm9kJwIXAv8VTnKMp3yaSLposlC6RAri7v+nure6+J64CjaR8W/WUapJ6aeQbUKdNqqfSK99WHaWaRBpTqgO48m3VUapJpDGlOoAr31YdpZpEGlPqF7PS4jyVtbeXvxOLUk0i+ZbqFrhUR6kmkcakAJ4DSjWJNKbUp1CkOko1iTQetcBFRDJKAVyO0GQgkWxRCkUArfMskkVqgQugyUAiWaQALkA8k4GUghGpLwVwAaKvO6P1WETqTwFcgOiTgZSCEak/BXABok8G0nosIvWnAJ5DYXPRUdZ51tK/IvWnAJ4zSeWitR6LSP0pgOdMUrlorcciUn/m7uF/2OydwB3AdMCBP3H3R0bbv6urywuFQujzSWVNTUHLu5RZkBoRkewxs83u3lW6PWoL/J+Ab7n77wAzgWciHk8iUi5apHGEDuBmdiIwH/gqgLsfcPdfxVQuCUm5aJHGEaUFPhUYBL5mZj82szvM7ISYyiUhKRctkiI1np4cOgduZl3AJuD33P1RM/sn4DV3/+uS/ZYDywHa29vn7Cx37y8RkbwpXSEOgq/DIVpUo+XAowTw3wQ2uXtn8f15wAp3v2i0n1Enpog0jM7O8jer7egIJlqMQ+ydmO7+f8DzZnZ6cdMFwNawxxMRyZU6TE+OOgrleqDPzJ4CZgF/F7lECdFKeiISqzoMCYsUwN19i7t3uftZ7n6Zu78aV8HqSSvpiUjs6jAkTDMx0Up6IlIDdRgSFmkm5niltRNTsxdFJM1qNRMzFzR7UUSySAEczV4UkWxSACfdsxc1OkZERnNM0gVIi+7udATskUoncg2PjoH0lVVE6k8t8BTT6BgRGYsCeIrpPpMiMhYF8BTT6BgRGYsCeIppdIyIjEUBPMXSPDpGRJKnAJ5y3d3BypNDQ8GzgrfkjsbKhqZhhCKSHI2VjUQtcBFJjsbKRqIALiK1NVaKRGNlI1EAF5HaqbTYvsbKRqIALiK1UylForGykSiAi8j49PXB5MnB2Nbhx+TJ5UePVEqRaKxsJBqFIiLV6+uDj30MDh48evvu3cF2ODr4treXvzP7yBRJGleSy4hILXAz6zezn5jZFjNL3612RCReK1e+PXgPO3jw7aNHlCKpqTha4H/g7i/HcBwRSbtKo0NKPx9uWa9cGXzW3h4Eb7W4Y5H7HLgmeYnEqNLokHKfazpxzUQN4A6sN7PNZrY8jgLFqdIIJhEZp54eaG4u/1lzs1IjdRY1gP++u88G/hD4tJnNL93BzJabWcHMCoODgxFPNz6a5CUyQhxfR7u74Wtfg9bWo7e3tgbb1bquK3P3eA5kdjPwhrt/YbR9urq6vFCoX19nU1PQ8i5lFnybE2kYpWuOQNCZqCF7mWBmm929q3R76Ba4mZ1gZpOGXwOLgKfDFzF+muQlDWe0Vra+juZSlFEopwD/bmbDx7nX3b8VS6li0tNTvtGhNJ3k0lgr+2nNkVwK3QJ39x3uPrP4ONPdUxcWNclLcilMK1tfR3Mp9zMxNclLciVsK/uee/R1NIdyPw5cJFfCtrL1dTSXFMBFsmSsVnalaeuaUJM7CuAiaVDtGG21smUEBXCRpI1nyrBa2TKCArhI0sYzRlutbBkhtpmY1aj3TEyRTNCUYakg9pmYIhITjdGWkBTARZKmmx5ISArgIklTXltCyv1MTJFM0JRhCUEtcBGRjFIAFyml+/BJRiiFIjLSWItFKcUhKaMWuMhIuvGBZIgCuMhIuvGBZIgCuMhImlQjGaIALjKSJtVIhkQO4GY2wcx+bGbfjKNAIonSpBrJkDhGodwAPAP8RgzHEkmeJtVIRkRqgZvZFOAi4I54iiMiItWKmkL5R+AvAa15KSJSZ6EDuJldDOxy980V9ltuZgUzKwwODoY9neSdZj+KjFuUFvjvAZeYWT/wr8D5ZramdCd373X3Lnfvamtri3A6ya3x3FJMRI6I5Y48ZrYA+At3v3is/XRHHimrszMI2qU6OoL7Ooo0ON2RR9JLsx9FQoklgLv7w5Va3yKj0uxHkVDUApf4hO2I1OxHkVAUwCUeUToiNftRJJRYOjGrpU7MHFNHpEjNqBNTaksdkSJ1pwAu8VBHpEjdKYDLW6LMhlRHpEjdKYBLIOpsSHVEitSdOjEloE5IkdRSJ6aMTZ2QIpmjAJ5HYXLZ6oQUyRwF8LwYDtpm8NGPjj+XrU5IkcxRAM+DkR2QEATukfbuhZUrxz6GOiFFMkedmHkwWgfkSGYwpBsniWSROjGzqpp8djUdjcpli+SOAniaVTs2u1JwVi5bJJcUwNNs5cogfz1SuXx2uQ5Is+BZuWyR3FIAT7Nqx2aX64C8556g1d7fr+AtklPHJF0AGUN7e/nOyXIpk+5uBWqRBhO6BW5mE83sMTN70sx+amZ/E2fBBI3NFpExRUmh/Bo4391nArOAJWY2L5ZSSUBjs0VkDKFTKB4MIH+j+La5+KjfoPJGodSIiIwiUiemmU0wsy3ALuA77v5oLKUSEZGKIgVwdz/s7rOAKcA5Zja9dB8zW25mBTMrDA4ORjmdiIiMEMswQnf/FbABWFLms15373L3rra2tjhOJyIiRBuF0mZm7yy+Ph64EHg2pnKJiEgFUVrgpwIbzOwp4HGCHPg34ylWAqLcD1JEJAFRRqE8BZwdY1mSM7zmyPC09eE1R0AjQEQktTSVHqpfc0REJEXyH8CjLMeq+0GKSIrlO4BHXY5Va2iLSIrlO4BHWY5Va46ISMrlO4BHWY5Va46ISMrlezlZLccqIjmW/hZ4lPHZSo2ISI6lO4BX2wk5GqVGRCTHLFgVtj66urq8UChU/wOdneVTIB0dwa3CREQagJltdveu0u3pboFrfLaIyKjSHcA1PltEZFTpDuDqhBQRGVW6A7g6IUVERpX+ceAany0iUla6W+AiIjIqBXARkYxSABcRySgFcBGRjFIAFxHJqLpOpTezQaDM3PijTAZerkNx0qZR6w2qeyPWvVHrDeHq3uHubaUb6xrAq2FmhXJz/vOuUesNqnsj1r1R6w3x1l0pFBGRjFIAFxHJqDQG8N6kC5CQRq03qO6NqFHrDTHWPXU5cBERqU4aW+AiIlKFRAK4mS0xs5+Z2c/NbEWZz48zs/uLnz9qZp0JFLMmqqj7n5vZVjN7ysy+Z2YdSZSzFirVfcR+S83MzSwXoxSqqbeZfbD4e/+pmd1b7zLWShX/3tvNbIOZ/bj4b/79SZQzbma22sx2mdnTo3xuZvbF4nV5ysxmhzqRu9f1AUwAfgGcBhwLPAlMK9nnU8BXiq8/BNxf73ImWPc/AFqKr69tpLoX95sEbAQ2AV1Jl7tOv/P3Aj8GTiq+f1fS5a5j3XuBa4uvpwH9SZc7prrPB2YDT4/y+fuB/wIMmAc8GuY8SbTAzwF+7u473P0A8K/ApSX7XArcVXy9FrjAzKyOZayVinV39w3uvrf4dhMwpc5lrJVqfu8AtwCfA/bXs3A1VE29Pwn8s7u/CuDuu+pcxlqppu4O/Ebx9YnAi3UsX824+0bglTF2uRS42wObgHea2anjPU8SAfy3gOdHvB8obiu7j7sfAvYArXUpXW1VU/eRPk7wVzoPKta9+DXyPe7+n/UsWI1V8zv/beC3zeyHZrbJzJbUrXS1VU3dbwY+YmYDwEPA9fUpWuLGGwvKSv8NHRqUmX0E6ALel3RZ6sHMmoDbgGUJFyUJxxCkURYQfOPaaGYz3P1XSRaqTq4A7nT3fzCzc4F7zGy6uw8lXbAsSKIF/gLwnhHvpxS3ld3HzI4h+Gq1uy6lq61q6o6ZLQRWApe4+6/rVLZaq1T3ScB04GEz6yfIC67LQUdmNb/zAWCdux909/8FthEE9Kyrpu4fBx4AcPdHgIkEa4XkXVWxoJIkAvjjwHvNbKqZHUvQSbmuZJ91wFXF1x8Avu/FzH/GVay7mZ0N3E4QvPOSC4UKdXf3Pe4+2d073b2TIP9/ibsXkilubKr59/4NgtY3ZjaZIKWyo45lrJVq6v4ccAGAmZ1BEMAH61rKZKwDriyORpkH7HH3l8Z9lIR6aN9P0Mr4BbCyuO1vCf7DQvBLfBD4OfAYcFrSvcp1rPt3gV8CW4qPdUmXuV51L9n3YXIwCqXK37kRpI+2Aj8BPpR0metY92nADwlGqGwBFiVd5pjqfR/wEnCQ4BvWx4FrgGtG/M7/uXhdfhL237pmYoqIZJRmYoqIZJQCuIhIRimAi4hklAK4iEhGKYCLiGSUAriISEYpgIuIZJQCuIhIRv0/2aSI16sim/wAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "y_ = linear_model(x_train)\n", - "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", - "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "从上面的例子可以看到,更新之后红色的线跑到了蓝色的线下面,没有特别好的拟合蓝色的真实值,所以我们需要在进行几次更新" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 19, loss: 9.138844332292493\n", - "epoch: 39, loss: 8.31670591484358\n", - "epoch: 59, loss: 8.010376750480548\n", - "epoch: 79, loss: 7.896237967760094\n", - "epoch: 99, loss: 7.853709612500179\n" - ] - } - ], - "source": [ - "for e in range(100): # 进行 100 次更新\n", - " y_ = linear_model(x_train)\n", - " loss = get_loss(y_, y_train)\n", - " \n", - " w.grad.zero_() # 记得归零梯度\n", - " b.grad.zero_() # 记得归零梯度\n", - " loss.backward()\n", - " \n", - " w.data = w.data - 1e-2 * w.grad.data # 更新 w\n", - " b.data = b.data - 1e-2 * b.grad.data # 更新 b \n", - " if (e + 1) % 20 == 0:\n", - " print('epoch: {}, loss: {}'.format(e, loss.item()))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXXklEQVR4nO3df5BV5X3H8fd3cRVXqTFAHBPCLs4kVARFdrXYGZFEBKqZRGum0VmjpFH8Ua1tJ86Y8IdOlclkppXWJKNuLLG6a6pidJjGJtRfJU1QXAyaFBMwuOCiDetqqPwKsPvtH2cXYb3Lnrv3nnOec8/nNXNn95579t7vc5f98NznPOc55u6IiEi46rIuQEREjkxBLSISOAW1iEjgFNQiIoFTUIuIBO6oJJ50woQJ3tTUlMRTi4jUpHXr1r3j7hNLPZZIUDc1NdHZ2ZnEU4uI1CQz2zLcYxr6EBEJnIJaRCRwCmoRkcAlMkZdyv79++nu7mbv3r1pvWTNGzt2LJMmTaK+vj7rUkQkQakFdXd3N+PGjaOpqQkzS+tla5a709vbS3d3N1OmTMm6HBFJUGpDH3v37mX8+PEK6SoxM8aPH69PKCIB6OiApiaoq4u+dnRU9/lT61EDCukq0/spkr2ODli8GHbvju5v2RLdB2htrc5r6GCiiEgFliz5IKQH7d4dba8WBXVMTU1NvPPOO1mXISKB2bq1vO2jEWxQJznm4+709/dX7wlFpLAmTy5v+2gEGdSDYz5btoD7B2M+lYR1V1cXU6dO5corr2T69OnccccdnHXWWZx++uncdtttB/e7+OKLaW5u5rTTTqOtra0KrRGRWrZ0KTQ0HL6toSHaXi1BBnVSYz6bNm3ihhtuYNmyZWzbto21a9eyfv161q1bx+rVqwFYvnw569ato7Ozk7vvvpve3t7KXlREalprK7S1QWMjmEVf29qqdyARUp71EVdSYz6NjY3Mnj2br33ta6xatYozzzwTgJ07d7Jp0ybmzJnD3XffzRNPPAHAm2++yaZNmxg/fnxlLywiNa21tbrBPFSQQT15cjTcUWp7JY477jggGqP++te/zrXXXnvY488//zxPP/00a9asoaGhgblz52qesohkLsihj6THfBYsWMDy5cvZuXMnANu2bWP79u3s2LGDE088kYaGBn7961/zwgsvVOcFRUQqEGSPevAjxJIl0XDH5MlRSFfro8X8+fN57bXXOOeccwA4/vjjaW9vZ+HChdx7772ceuqpTJ06ldmzZ1fnBUVEKmDuXvUnbWlp8aEXDnjttdc49dRTq/5aRaf3VaQ2mNk6d28p9ViQQx8iIvIBBbWISOAU1CIigVNQi4gETkEtIhI4BbWISOAU1CU88MADvPXWWwfvX3311WzYsKHi5+3q6uLhhx8u++cWLVrEihUrKn59EcmncIM66WvbHMHQoL7//vuZNm1axc872qAWkWILM6iTWOcUaG9v5+yzz2bmzJlce+219PX1sWjRIqZPn86MGTNYtmwZK1asoLOzk9bWVmbOnMmePXuYO3cugyfwHH/88dxyyy2cdtppzJs3j7Vr1zJ37lxOOeUUVq5cCUSBfO655zJr1ixmzZrFz3/+cwBuvfVWfvrTnzJz5kyWLVtGX18ft9xyy8HlVu+77z4gWovkxhtvZOrUqcybN4/t27dX1G4RyTl3r/qtubnZh9qwYcOHtg2rsdE9iujDb42N8Z+jxOt/7nOf83379rm7+/XXX++33367z5s37+A+7733nru7n3feef7SSy8d3H7ofcCfeuopd3e/+OKL/YILLvB9+/b5+vXr/YwzznB39127dvmePXvc3X3jxo0++H4899xzftFFFx183vvuu8/vuOMOd3ffu3evNzc3++bNm/3xxx/3efPm+YEDB3zbtm1+wgkn+GOPPTZsu0Qk/4BOHyZTg1zrI4l1Tp955hnWrVvHWWedBcCePXtYuHAhmzdv5qabbuKiiy5i/vz5Iz7P0UcfzcKFCwGYMWMGxxxzDPX19cyYMYOuri4A9u/fz4033sj69esZM2YMGzduLPlcq1at4tVXXz04/rxjxw42bdrE6tWrufzyyxkzZgwf//jH+exnPzvqdotI/oUZ1Amsc+ruXHXVVXzzm988bPvSpUv5yU9+wr333sujjz7K8uXLj/g89fX1B6/+XVdXxzHHHHPw+wMHDgCwbNkyTjrpJF555RX6+/sZO3bssDV9+9vfZsGCBYdtf+qpp0bVRhGpTWGOUSewzun555/PihUrDo73vvvuu2zZsoX+/n4uvfRS7rzzTl5++WUAxo0bx/vvvz/q19qxYwcnn3wydXV1PPTQQ/T19ZV83gULFnDPPfewf/9+ADZu3MiuXbuYM2cOjzzyCH19fbz99ts899xzo65FRPIvzB51AuucTps2jTvvvJP58+fT399PfX09d911F5dccsnBC90O9rYXLVrEddddx7HHHsuaNWvKfq0bbriBSy+9lAcffJCFCxcevGDB6aefzpgxYzjjjDNYtGgRN998M11dXcyaNQt3Z+LEiTz55JNccsklPPvss0ybNo3JkycfXI5VRIpJy5zmnN5XkdqgZU5FRHJMQS0iErhUgzqJYZYi0/spUgypBfXYsWPp7e1VuFSJu9Pb2zvs1D8RqR2xZn2Y2c3ANYAB33P3fyr3hSZNmkR3dzc9PT3l/qgMY+zYsUyaNCnrMkQkYSMGtZlNJwrps4F9wI/N7N/d/fVyXqi+vp4pU6aMrkoRkQKLM/RxKvCiu+929wPAfwF/nmxZIiIyKE5Q/wo418zGm1kDcCHwyaE7mdliM+s0s04Nb4hIoSS8LPOIQx/u/pqZfQtYBewC1gN9JfZrA9ogOuGlqlWKiIRqcFnm3buj+4PLMkNFZ1MfKtasD3f/F3dvdvc5wHtA6eXgRESKZsmSD0J60O7d0fYqiTvr42Puvt3MJhONT8+uWgUiInmWwLLMQ8VdlOlxMxsP7Af+yt1/X7UKRETyLIFlmYeKO/RxrrtPc/cz3P2Zqr26iEjeJbAs81Ba60NEpBKtrdDWBo2NYBZ9bWur2oFECHU9ahGRPGltrWowD6UetYhI4BTUIiKBU1CLiAROQS0iEjgFtYhI4BTUIjUk4bWBJCOanidSI1JYG0gyoh61SI1IYW0gyYiCWqRGpLA2kGREQS1SI4ZbA6iKawNJRhTUIjUihbWBJCMKapEakcLaQJIRzfoQqSEJrw0kGVGPWkQkcApqEZHAKahFRAKnoBYRCZyCWkTSoYVIRk2zPkQkeVqIpCLqUYtI8rQQSUUU1CKSPC1EUhEFtYhUx5HGoLUQSUUU1CJSucEx6C1bwP2DMejBsNZCJBVRUItI5UYag9ZCJBUxd6/6k7a0tHhnZ2fVn1dEAlVXF/WkhzKD/v7068khM1vn7i2lHlOPWkQqpzHoRCmoRWR4HR0wYULUMzaLvi91oorGoBOloBYJUOYn8Q0G9BVXQG/vB9t7e+m76isfLkhj0IlSUEthZR6GwxhpAkVqBRwa0IcY07efnTeXOFGltRW6uqIx6a6uD4V0qO93HuhgohTS0DOaIfqkHkInsKkpCuehGhuj/MusgEP0Y9R5/IOEIb/foTjSwcRYQW1mfwtcDTjwS+Ar7r53uP0V1BK6zMPwCDKfQDFcAYfoopEm74r9lCG/36GoaNaHmX0C+Gugxd2nA2OAy6pboki6Qj6jOfMJFCO80F7quWt8eQcJQ36/8yDuGPVRwLFmdhTQALyVXEkiycs8DI8g8wkUJQrwgVsP47m+/vv8yT+XN14R8vudByMGtbtvA/4B2Aq8Dexw91VD9zOzxWbWaWadPT091a9UpIoyD8MjyHwCxZACdo5v5Obx7Ywx56zGd5j3/dayawn5/c6DEceozexE4HHgS8DvgceAFe7ePtzPaIxa8qCjIzrDeevWqGe3dKkObCVJ7/eRVXpm4jzgDXfvcff9wA+BP61mgSJZGGE2WT4FPAeuJt/vlMS5wstWYLaZNQB7gPMBdZdFQqOrqNSsOGPULwIrgJeJpubVAW0J1yUi5dJVVGpWrFkf7n6bu/+xu0939y+7+x+SLkxEjqDUEEfGc+ACHnXJPV3cViRvhhvi+OhHS5/2ncIcOI26JEtrfUiq1OuqguGGOCCzOXAadUmWglpSk/liQ7ViuKGMd9/NbAK2zjxMloJaUlP0XldZnyZGe6HYjObA6czDZCmoJTVF7nWV9WkihxeKDbCk2uLuVb81Nze7yFCNje5R8hx+a2zMurLkldX2ODu3t0f3zaKv7e0ptOLIAiwpV4BOHyZTtR61pKbIaxKXtXRp5uucShZ0cVsJQuaLDWWorDFcDfjKEApqSVVR13tYuhQW1XfwBk30UccbNLGovqP0GK4GfGUInfAikoJWOviSLeYoonGfJrbwPVs88Ac45H+rwf+9tNScDNAYtUgadC0qGYHGqEWyVuS5iVIxBbVIGnJ6gFCn/IdBQS2ShhweINQp/+FQUIukIYdzE4t+yn9IdDBRRErSeTfp0sFEESlbTofVa5KCWkRKyuGwes1SUItISTkcVq9ZCmqRHEpr2lxRT/kPjU4hF8kZXZ+weNSjluLK6dkcmjZXPOpRSzHluFuqs9GLRz1qKaYcd0s1ba54FNRSTDnulmraXPEoqKWYctwt1bS54lFQSzHlvFuqaXPFoqCWYlK3VHJEsz6kuFpbFcySC+pRi4gETkEtIhI4BbWISOBGDGozm2pm6w+5/Z+Z/U0KtYmICDGC2t1/4+4z3X0m0AzsBp5IujARyU5Ol0GpWeXO+jgf+K27b0miGBHJXo6XQalZ5Y5RXwb8IIlCRCQMOV4GpWbFDmozOxr4PPDYMI8vNrNOM+vs6empVn0ikrIcL4NSs8rpUf8Z8LK7/67Ug+7e5u4t7t4yceLE6lQntUeDn8HL8TIoNaucoL6cBIc99PdbAIODn1u2gPsHg5/6ZQcl58ug1KRYQW1mxwEXAD9Mogj9/RaEBj9zQcughMfcvepP2tLS4p2dnbH3b2qKwnmoxsZoZTCpEXV10f/EQ5lFy8CJFJiZrXP3llKPBXFmog5eFIQGP0VGJYig1t9vQWjwU2RUgghq/f3GF8RB19EWocFPkdFx96rfmpubvVzt7e6Nje5m0df29rKfoua1t7s3NLhHA73RraEh5fcqiCJEag/Q6cNkahAHEyWeIA66BlGESO0J/mCixBPEQdcgihApFgV1jgRx0DWIIkSKRUGdI0EcdA2iCJFiUVDnSBCTJoIoQqRYdDBRRCQAOpgohwtiMraIxFXuFV4k73T5DpHcUY+6aLSCnUjuKKiLRvOgRXJHQV00mgctkjsK6qLRPGiR3FFQ59loZm9oHrRI7iio82YwnM3gy18e1fXLOmiliS7q6KeJLjpQSIuETEGdJ4deXBI+fFmrGLM3dH1KkfzRmYl5MtwSo4ca4fqDWqVUJEw6M7FWxJlCN8LsjWrMztOJjSLpUlDnyUhT6GLM3qh0dp6GTkTSp6AOQdwuaqmpdWbR15izNyqdnacTG0XSp6DOWjld1FJT6x56KPq5rq5YU+wqnZ2nExtF0qeDiVnL2dG9nJUrkhs6mBiynHVRdWKjSPoU1Fmr4OheFrMvdGKjSPoU1FkbZRc1y9kXra3RMEd/f+yhcRGpgII6a6Psomr2hUhx6GBiTtXVffgMchjxxEQRCZQOJtYgLSstUhwK6pzS7AuR4lBQ55RmX4gUR6yrkJvZR4D7gemAA3/p7msSrEtiaG1VMIsUQaygBv4Z+LG7f9HMjgYaRvoBERGpjhGD2sxOAOYAiwDcfR+wL9myRERkUJwx6ilAD/B9M/uFmd1vZscN3cnMFptZp5l19vT0VL1QEZGiihPURwGzgHvc/UxgF3Dr0J3cvc3dW9y9ZeLEiVUuU0SkuOIEdTfQ7e4vDtxfQRTcIiKSghGD2t3/F3jTzKYObDof2JBoVUnRNaREJIfizvq4CegYmPGxGfhKciUlZHAVo8EFMgZXMQLNcRORoBVnrQ+teC8iAdNaH5C7BfpFRAYVJqh3frT0akXDbRcRCUXtBPUIBwq/wVJ2DTmhchcNfAOtYiQiYauNoI5xuZPvvNvKNbTRRSP9GF00cg1tfOddHUgUkbDVxsHEGAcKdSxRREJW+wcTYxwo1PrNIpJXtRHUMS53ovWbRSSvaiOoY3aXdfVsEcmj2ghqdZdFpIaFE9SVrsOR4+6yliARkSOJu9ZHsgq8DkeBmy4iMYUxPa/Ac+cK3HQROUT40/MKvA5HgZsuIjGFEdQxptfVqgI3XURiCiOoC3w2SoGbLiIxhRHUBZ5eV+Cmi0hMYRxMFBEpuPAPJoqIyLAU1CIigVNQi4gETkEtIhI4BbWISOAU1CIigVNQi4gETkEtIhI4BbWISOAU1CIigVNQi4gETkEtIhI4BbWISOAU1CIigVNQi4gELtZVyM2sC3gf6AMODLdmqoiIVF+soB7wGXd/J7FKRESkpJoZ+ujogKYmqKuLvnZ0ZF2RiEh1xA1qB1aZ2TozW1xqBzNbbGadZtbZ09NTvQpj6OiAxYthyxZwj74uXqywFpHaEOuaiWb2CXffZmYfA/4TuMndVw+3f9rXTGxqisJ5qMZG6OpKrQwRkVGr+JqJ7r5t4Ot24Ang7OqVV7mtW8vbLiKSJyMGtZkdZ2bjBr8H5gO/SrqwckyeXN52EZE8idOjPgn4bzN7BVgL/Mjdf5xsWeVZuhQaGg7f1tAQbRcRybsRp+e5+2bgjBRqGbXW1ujrkiXRcMfkyVFID24XEcmzcuZRB621VcEsIrWpZuZRi4jUKgW1iEjgFNQiIoFTUIuIBE5BLSISuFinkJf9pGY9QImTug+aABR1JT61vZiK2vaithvKb3uju08s9UAiQT0SM+ss6prWarvaXiRFbTdUt+0a+hARCZyCWkQkcFkFdVtGrxsCtb2Yitr2orYbqtj2TMaoRUQkPg19iIgETkEtIhK4RIPazBaa2W/M7HUzu7XE48eY2SMDj79oZk1J1pOmGG3/OzPbYGavmtkzZtaYRZ1JGKnth+x3qZm5mdXE9K047Tazvxj4vf+PmT2cdo1JifHvfbKZPWdmvxj4N39hFnVWm5ktN7PtZlbyYioWuXvgfXnVzGaN6oXcPZEbMAb4LXAKcDTwCjBtyD43APcOfH8Z8EhS9aR5i9n2zwANA99fX6S2D+w3DlgNvAC0ZF13Sr/zTwG/AE4cuP+xrOtOse1twPUD308DurKuu0ptnwPMAn41zOMXAv8BGDAbeHE0r5Nkj/ps4HV33+zu+4B/A74wZJ8vAP868P0K4HwzswRrSsuIbXf359x998DdF4BJKdeYlDi/d4A7gG8Be9MsLkFx2n0N8F13fw8OXoO0FsRpuwN/NPD9CcBbKdaXGI8u8v3uEXb5AvCgR14APmJmJ5f7OkkG9SeANw+53z2wreQ+7n4A2AGMT7CmtMRp+6G+SvS/bi0Yse0DH/8+6e4/SrOwhMX5nX8a+LSZ/czMXjCzhalVl6w4bb8duMLMuoGngJvSKS1z5WZBSTVzhZe8MrMrgBbgvKxrSYOZ1QF3AYsyLiULRxENf8wl+gS12sxmuPvvsywqJZcDD7j7P5rZOcBDZjbd3fuzLiwPkuxRbwM+ecj9SQPbSu5jZkcRfSTqTbCmtMRpO2Y2D1gCfN7d/5BSbUkbqe3jgOnA82bWRTRut7IGDijG+Z13Ayvdfb+7vwFsJAruvIvT9q8CjwK4+xpgLNGiRbUuVhaMJMmgfgn4lJlNMbOjiQ4Wrhyyz0rgqoHvvwg86wMj8Dk3YtvN7EzgPqKQrpWxShih7e6+w90nuHuTuzcRjc9/3t07sym3auL8e3+SqDeNmU0gGgrZnGKNSYnT9q3A+QBmdipRUPekWmU2VgJXDsz+mA3scPe3y36WhI+IXkjUa/gtsGRg298T/WFC9Mt6DHgdWAuckvVR3BTb/jTwO2D9wG1l1jWn1fYh+z5PDcz6iPk7N6Jhnw3AL4HLsq45xbZPA35GNCNkPTA/65qr1O4fAG8D+4k+MX0VuA647pDf+XcH3pdfjvbfuk4hFxEJnM5MFBEJnIJaRCRwCmoRkcApqEVEAqegFhEJnIJaRCRwCmoRkcD9P4jSg7+0cwW/AAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "y_ = linear_model(x_train)\n", - "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", - "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "经过 100 次更新,我们发现红色的预测结果已经比较好的拟合了蓝色的真实值。\n", - "\n", - "现在你已经学会了你的第一个机器学习模型了,再接再厉,完成下面的小练习。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.4 练习题\n", - "\n", - "重启 notebook 运行上面的线性回归模型,但是改变训练次数以及不同的学习率进行尝试得到不同的结果" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 多项式回归模型" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "下面我们更进一步,讲一讲多项式回归。什么是多项式回归呢?非常简单,根据上面的线性回归模型\n", - "\n", - "$$\n", - "\\hat{y} = w x + b\n", - "$$\n", - "\n", - "这里是关于 x 的一个一次多项式,这个模型比较简单,没有办法拟合比较复杂的模型,所以我们可以使用更高次的模型,比如\n", - "\n", - "$$\n", - "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n", - "$$\n", - "\n", - "这样就能够拟合更加复杂的模型,这就是多项式模型,这里使用了 x 的更高次,同理还有多元回归模型,形式也是一样的,只是出了使用 x,还是更多的变量,比如 y、z 等等,同时他们的 loss 函数和简单的线性回归模型是一致的。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "首先我们可以先定义一个需要拟合的目标函数,这个函数是个三次的多项式" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n" - ] - } - ], - "source": [ - "# 定义一个多变量函数\n", - "\n", - "w_target = np.array([0.5, 3, 2.4]) # 定义参数\n", - "b_target = np.array([0.9]) # 定义参数\n", - "\n", - "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n", - " b_target[0], w_target[0], w_target[1], w_target[2]) # 打印出函数的式子\n", - "\n", - "print(f_des)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "我们可以先画出这个多项式的图像" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlIklEQVR4nO3deXxU9b3/8dcnG2FfQkAkQED2TYGwVcuDFheqXjdai61L1YpetbWtt7Xq7bX9aW+1trbaai1Vi1aLIGq1dSlqVVzKLpssBtmSCCQsCQSyznx+f2TgorIlM8mZmbyfj0ceM3PmzPl+zhDeOfOd7/kec3dERCQ5pQRdgIiINB6FvIhIElPIi4gkMYW8iEgSU8iLiCQxhbyISBI77pA3s8fMrNjMVh2yrJOZvWZm+ZHbjpHlZmYPmNl6M1thZiMbo3gRETm6+hzJzwAmf2bZj4E33L0f8EbkMcBXgH6Rn2nAH6IrU0REGsLqczKUmeUC/3D3oZHH64CJ7r7VzLoBb7n7ADP7Y+T+zM+ud7Ttd+7c2XNzcxu2JyIizdSSJUt2uHv24Z5Li3LbXQ8J7m1A18j97kDBIesVRpYdNeRzc3NZvHhxlCWJiDQvZrb5SM/F7ItXr/tIUO85EsxsmpktNrPFJSUlsSpHRESIPuS3R7ppiNwWR5YXAT0OWS8nsuxz3H26u+e5e1529mE/bYiISANFG/IvAldE7l8BvHDI8ssjo2zGAWXH6o8XEZHYO+4+eTObCUwEOptZIXAHcDcw28yuBjYDF0dWfxk4G1gP7AeubGiBNTU1FBYWUllZ2dBNyGdkZmaSk5NDenp60KWISCM77pB390uO8NSkw6zrwA0NLepQhYWFtG3bltzcXMwsFpts1tydnTt3UlhYSO/evYMuR0QaWdyf8VpZWUlWVpYCPkbMjKysLH0yEmkm4j7kAQV8jOn9FGk+EiLkRUSS2f2v57Ngw85G2bZCvgnk5uayY8eOoMsQkTi0oaSc37z+EQs27mqU7Svk68HdCYfDQZcRN3WISPT+Mn8z6anG1DE9jr1yAyjkj2HTpk0MGDCAyy+/nKFDh1JQUMC9997L6NGjGT58OHfcccfBdS+44AJGjRrFkCFDmD59+jG3/eqrrzJy5EhOPvlkJk2qG6T005/+lF/96lcH1xk6dCibNm36XB133nknP/zhDw+uN2PGDG688UYAnnzyScaMGcMpp5zCtddeSygUitXbISIxtK+qljmLCzl7WDe6tM1slDainbumSf3s7x+y+pM9Md3m4BPbccd/DDnqOvn5+Tz++OOMGzeOuXPnkp+fz8KFC3F3zjvvPObNm8eECRN47LHH6NSpExUVFYwePZopU6aQlZV12G2WlJRwzTXXMG/ePHr37s2uXcf+qHZoHSUlJYwfP557770XgFmzZnH77bezZs0aZs2axXvvvUd6ejrXX389Tz31FJdffnn93xwRaVR/W1bE3qpaLh/fq9HaSKiQD0qvXr0YN24cAHPnzmXu3LmMGDECgPLycvLz85kwYQIPPPAAzz//PAAFBQXk5+cfMeTnz5/PhAkTDo5V79SpU73qyM7Opk+fPsyfP59+/fqxdu1aTj31VB588EGWLFnC6NGjAaioqKBLly7RvQEiEnPuzhPvb2bIie0Y2bNjo7WTUCF/rCPuxtK6deuD992dW2+9lWuvvfZT67z11lu8/vrr/Pvf/6ZVq1ZMnDixQWPR09LSPtXffug2Dq0DYOrUqcyePZuBAwdy4YUXYma4O1dccQW/+MUv6t22iDSdBRt3sW77Xn45ZXijDmtWn3w9nXXWWTz22GOUl5cDUFRURHFxMWVlZXTs2JFWrVqxdu1a5s+ff9TtjBs3jnnz5rFx40aAg901ubm5LF26FIClS5cefP5wLrzwQl544QVmzpzJ1KlTAZg0aRJz5syhuLj44HY3bz7iLKQiEpAn/r2J9i3T+Y+TT2zUdhLqSD4enHnmmaxZs4bx48cD0KZNG5588kkmT57Mww8/zKBBgxgwYMDBbpUjyc7OZvr06Vx00UWEw2G6dOnCa6+9xpQpU3jiiScYMmQIY8eOpX///kfcRseOHRk0aBCrV69mzJgxAAwePJi77rqLM888k3A4THp6Og8++CC9ejVen5+I1M/Wsgr++eF2rj6tNy0zUhu1rXpdGaqx5eXl+WcvGrJmzRoGDRoUUEXJS++rSHDum7uO3725nrf/60v0zGoV9fbMbIm75x3uOXXXiIg0oaraEH9duIUvD+gSk4A/FoW8iEgTenXVNnaUV3NZIw6bPFRChHw8dSklA72fIsF5/P1N5Ga1YkK/prkSXtyHfGZmJjt37lQwxciB+eQzMxvn7DoRObLlBaUs3VLKZeNzSUlpmtlg4350TU5ODoWFhegi37Fz4MpQItK0Hnl3I21bpHFxXtP9/4v7kE9PT9cVjEQk4RWVVvDyyq1cdWoubTOb7tKbcd9dIyKSDGa8V3di47dObdqDVoW8iEgj21tZw9MLCzh7WDe6d2jZpG0r5EVEGtmsRQXsrarlmi82fddzTELezL5vZh+a2Sozm2lmmWbW28wWmNl6M5tlZhmxaEtEJJHUhsL8+b1NjMntxPCcDk3eftQhb2bdge8Cee4+FEgFpgL3AL9x977AbuDqaNsSEUk0r364jaLSCr4dwFE8xK67Jg1oaWZpQCtgK/BlYE7k+ceBC2LUlohIQnB3/vTORnKzWjFpUNdAaog65N29CPgVsIW6cC8DlgCl7l4bWa0Q6H6415vZNDNbbGaLNRZeRJLJks27WV5QytWn9Sa1iU5++qxYdNd0BM4HegMnAq2Bycf7enef7u557p6Xnd00p/mKiDSFR97ZSPuW6UwZFdzJh7Horjkd2OjuJe5eAzwHnAp0iHTfAOQARTFoS0QkIWzasY9/rt7GN8f2pFVGcOedxiLktwDjzKyV1V3DahKwGngT+GpknSuAF2LQlohIQvjjvI/JSE3hyiY++emzYtEnv4C6L1iXAisj25wO3AL8wMzWA1nAo9G2JSKSCLaVVTJnSSEX5/Ugu22LQGuJyWcId78DuOMzizcAY2KxfRGRRPKndzYQdpg2oU/QpeiMVxGRWNq1r5q/LtjC+aecSI9OjX/lp2NRyIuIxNCM9zZSWRvi+oknBV0KoJAXEYmZvZU1zHh/E2cNPoG+XdoGXQ6gkBcRiZmnFmxhT2Ut138pPo7iQSEvIhITlTUhHnlnI1/s1zmQiciORCEvIhIDzywuYEd5FddP7Bt0KZ+ikBcRiVJNKMzDb29gZM8OjOvTKehyPkUhLyISpeeWFlJUWsENX+pL3Yn/8UMhLyISheraMA+8sZ7hOe358sAuQZfzOQp5EZEoPLOkgKLSCr5/Rv+4O4oHhbyISINV1Yb4/b/WM6JnByb2j8+p0hXyIiINNGtRAVvLKrn5jAFxeRQPCnkRkQaprAnx4JvrGZPbiVP7ZgVdzhEp5EVEGuCvC7awfU9V3PbFH6CQFxGpp4rqEA+99THj+2Qx/qT4PYoHhbyISL09OX8zO8rrjuLjnUJeRKQe9lXV8vDbH/PFfp0Z0zu+zm49HIW8iEg9PPruRnbuq06Io3hQyIuIHLcd5VX88e2POWtIV0b27Bh0OcdFIS8icpx+90Y+lbVhfjR5YNClHDeFvIjIcdi0Yx9PLdjC10f34KTsNkGXc9xiEvJm1sHM5pjZWjNbY2bjzayTmb1mZvmR28T4bCMichj3zl1HemoK35vUL+hS6iVWR/L3A6+6+0DgZGAN8GPgDXfvB7wReSwiknCWF5Ty0oqtXDOhD13aZQZdTr1EHfJm1h6YADwK4O7V7l4KnA88HlntceCCaNsSEWlq7s4vXllD5zYZTJvQJ+hy6i0WR/K9gRLgz2b2gZk9Ymatga7uvjWyzjag6+FebGbTzGyxmS0uKSmJQTkiIrHz1roS5m/YxXcn9aNNi7Sgy6m3WIR8GjAS+IO7jwD28ZmuGXd3wA/3Ynef7u557p6XnR2fU3WKSPMUCjt3v7KW3KxWXDKmZ9DlNEgsQr4QKHT3BZHHc6gL/e1m1g0gclscg7ZERJrMnCUFrNu+lx+eNZD01MQcjBh11e6+DSgwswGRRZOA1cCLwBWRZVcAL0TblohIU9lTWcO9/1xHXq+OnD3shKDLabBYdTB9B3jKzDKADcCV1P0BmW1mVwObgYtj1JaISKN74PV8du6rZsaVY+J6KuFjiUnIu/syIO8wT02KxfZFRJrS+uJyZry/ia/n9WBo9/ZBlxOVxOxkEhFpJO7O//vHalpmpPJfZw049gvinEJeROQQb6wpZt5HJXzv9P50btMi6HKippAXEYmoqg1x50ur6dulDZeP7xV0OTGhkBcRiXjs3U1s3rmf/zl3cMIOmfys5NgLEZEoFe+p5Pf/yuf0QV2Z0D95TsxUyIuIAHe+tIaakPPf5wwKupSYUsiLSLP31rpi/r78E274Ul9yO7cOupyYUsiLSLNWUR3iJy+sok92a66bmHizTB5L4k2pJiISQ/e/kU/BrgqenjaOFmmpQZcTczqSF5Fma+22PTzyzgYuzsthXJ+soMtpFAp5EWmWwmHn1udW0q5lOrd+Jbm+bD2UQl5EmqW/LtzCB1tK+e9zBtGxdUbQ5TQahbyINDvFeyq559W1nNo3iwtHdA+6nEalkBeRZsXd+e+/raKqNsxdFwxL6GmEj4dCXkSalReWfcLc1dv5rzP70zvJxsQfjkJeRJqN7Xsq+Z8XVjGqV0euPi35xsQfjkJeRJoFd+fHz66gOhTmV187mdSU5O6mOUAhLyLNwjNLCnlzXQm3TB7YLLppDlDIi0jS+6S0gjv/vppxfTpxxfjcoMtpUgp5EUlq7s4tz64g5M69Xz2ZlGbSTXOAQl5EktqTC7bwTv4Objt7ED06tQq6nCYXs5A3s1Qz+8DM/hF53NvMFpjZejObZWbJe0qZiMSlddv2ctc/VjOhfzbfHNsz6HICEcsj+ZuANYc8vgf4jbv3BXYDV8ewLRGRo6qsCfGdmUtpm5nOr792ctKf9HQkMQl5M8sBzgEeiTw24MvAnMgqjwMXxKItEZHjcec/VvPR9nLuu/hkstu2CLqcwMTqSP63wI+AcORxFlDq7rWRx4XAYSeIMLNpZrbYzBaXlJTEqBwRac5eWbmVpxZs4doJfZLqeq0NEXXIm9m5QLG7L2nI6919urvnuXtednbz/scQkegVlVZwy7MrODmnPTefOSDocgIXiytDnQqcZ2ZnA5lAO+B+oIOZpUWO5nOAohi0JSJyRLWhMDfN/ICwwwOXjCAjTQMIo34H3P1Wd89x91xgKvAvd/8m8Cbw1chqVwAvRNuWiMjR/Pb1fBZv3s1dFwylV1bzOav1aBrzz9wtwA/MbD11ffSPNmJbItLMvbZ6O79/cz1fG5XDBUk+R3x9xPRC3u7+FvBW5P4GYEwsty8icjgfl5Tzg1nLGNa9PXdeMDTocuKKOqxEJKGVV9Vy3V+WkJ6WwsOXjSIzPTXokuKKQl5EEpa786M5y/m4pJzfXzKC7h1aBl1S3FHIi0jCmj5vAy+v3MaPvzKQL/TtHHQ5cUkhLyIJ6b31O7jn1bWcM7wb13yxeVzlqSEU8iKScNYXl/OfTy6hb5c2/HLK8GY7L83xUMiLSELZWV7FVTMWkZGWwqNXjKZ1i5gOEkw6endEJGFU1oS45onFbN9TyaxrxzfL+eHrSyEvIgkhHHZufmY5HxSU8tA3RnJKjw5Bl5QQ1F0jIgnhV3PX8dKKrdz6lYF8ZVi3oMtJGAp5EYl7Ty/cwkNvfcw3xvbUSJp6UsiLSFx7ddVWbnt+JRP6Z/Oz84ZoJE09KeRFJG7N+6iE78z8gFN6dODhS0eSnqrIqi+9YyISlxZv2sW0vyymb5e2/PlbY2iVoXEiDaGQF5G48+EnZVw5YxHd2rfkiavG0L5VetAlJSyFvIjElQ0l5Vz+6ELatkjjyW+PbdYX4Y4FhbyIxI0NJeV8408LAPjLt8dqVskYUCeXiMSF/O17+cYjCwiHnSe/PZaTstsEXVJSUMiLSODWbN3DpY8sICXFeHraOPp1bRt0SUlD3TUiEqhVRWVc8qf5pKemMEsBH3M6kheRwHywZTeXP7aQdpnpzLxmHD2zNOFYrEV9JG9mPczsTTNbbWYfmtlNkeWdzOw1M8uP3HaMvlwRSRZvf1TCpY8soFPrDGZfN14B30hi0V1TC9zs7oOBccANZjYY+DHwhrv3A96IPBYRYfaiAq6asYheWa155trxGkXTiKLurnH3rcDWyP29ZrYG6A6cD0yMrPY48BZwS7TtiUjicnd++3o+97+Rz4T+2Tz0zZG00UU/GlVM310zywVGAAuArpE/AADbgK6xbEtEEktNKMxtz63kmSWFfG1UDv970TDNRdMEYhbyZtYGeBb4nrvvOXSmOHd3M/MjvG4aMA2gZ8+esSpHROJI2f4abpy5lHfyd3DTpH587/R+mk2yicTkz6iZpVMX8E+5+3ORxdvNrFvk+W5A8eFe6+7T3T3P3fOys7NjUY6IxJF12/Zy3oPvMn/DTu6ZMozvn9FfAd+EYjG6xoBHgTXuft8hT70IXBG5fwXwQrRtiUhieWnFVi586D32V4d4eto4vj5an9abWiy6a04FLgNWmtmyyLLbgLuB2WZ2NbAZuDgGbYlIAgiFnXv/uY6H3/6YkT078IdLR9G1XWbQZTVLsRhd8y5wpM9ek6LdvogklpK9Vfxg9jLeyd/BN8f25I7/GEJGmr5gDYrGLolIzLy5tpgfzlnO3spa7pkyTN0zcUAhLyJRq6wJcfcra5nx/iYGntCWv14zjv6agyYuKORFJCrrtu3lpqc/YO22vVx5ai63TB5IZnpq0GVJhEJeRBqkJhRm+rwN3P9GPu0y05hx5WgmDugSdFnyGQp5Eam3ZQWl/PjZFazdtpezh53Az84bqsv0xSmFvIgct31Vtfx67kfMeH8jXdpmMv2yUZw55ISgy5KjUMiLyDG5O6+s2sbPX1pDUWkFl43rxY8mD6BtZnrQpckxKORF5KhWFJZy5z9Ws2jTbgae0JY5140nL7dT0GXJcVLIi8hhbS2r4N5X1/HcB0V0bpPBLy4axsV5PUhN0bwziUQhLyKfsqO8iunzNvDEvzcRdvjPiSdx/cST1DWToBTyIgJA8d5Kpr+9gScXbKa6Nsz5p3TnB2f0p0cnXZYvkSnkRZq5T0orePTdjTwVCfcLRnTnxi/1pU92m6BLkxhQyIs0Q+7Oks27+fP7m3h11TYALjilOzd+uS+9O7cOuDqJJYW8SDNSWRPilVVb+fN7m1hRWEa7zDSuPq03l43rpW6ZJKWQF0ly7s6yglKeXVrIi8s+YU9lLSdlt+bOC4YyZWR3WmUoBpKZ/nVFklTBrv38fcUnPLukkI9L9pGZnsJZQ07gq6NyOPWkzqRoKGSzoJAXSRLuTn5xOa+u2sY/P9zGh5/sAWB0bkemTejD2cO6aRhkM6SQF0lg+6pqWbhxF++u38G/1hazccc+AEb27MBtZw9k8pBu9MxSX3tzppAXSSD7q2tZUVjGgg27eG/9DpZu2U1t2MlIS2Fs705cfVpvzhzclS66nqpEKORF4lRNKMzGHftYUVjGB1t288GWUtZt30so7JjBsO7tuWZCH07r25lRvTrqQh1yWAp5kYBV1oTYsms/m3fuZ31xOeu27WHttr18XFJOTcgBaJuZxik9OnDD4L6M6NmBET060KFVRsCVSyJo9JA3s8nA/UAq8Ii7393YbYrEA3dnX3WI3fuqKSmvonhPJdv3VLFtTyXb91RSuLuCzTv3sX1P1aded2L7TAac0JaJA7ow8IS2DDmxHSdlt9FoGGmQRg15M0sFHgTOAAqBRWb2oruvbsx2pensr66lZG8VO8qrKN1fw57KGvZU1LKnou7+vuoQldUhKmrqfiprQlTVhgmFnZqQUxuqux/yuiPWyM1BqSlGih24NVJTjLQUIy01hdQUIz3VSEtJIT3VSE9NIS01cj8lhbTIsvRUiyxPIS3l/7Zx4DYlxTCrayfFDAMcCLvjXhfWYa/rPqkOhamp9YP391fXsr86xP6qEPsi98sqaijdX0NZRfXBI/FDpaUYXdq2oHvHlpzWN5vcrFb0zGpFr6zW9O7cmvYtNQJGYqexj+THAOvdfQOAmT0NnA8o5BNAKOwU7a6gYPd+Cnfvp3B3ReRnP8V7q9ixt4p91aEjvr5leiqtW6TRMiOFzLRUWmakkpmWSpsWaZGA/b8ATjE4cJxqVnfP3Qk5hMNO2L3uj0HYqY3c1oTCVNWEKQ/V1v3BCIepCdUtrwmFqQ051ZHbA8/FyoE/MK0y0miVkRr5qbvfr0sbOrTKoEOrdDq2SqdDywyy27agS7sWdG2XSadWGToqlybT2CHfHSg45HEhMLaR25R6cneKSitYVVTGR9vLWV9cTn5xORtKyqmqDR9cLzXF6NY+k+4dWnJyTgey27agc5sWdG6TQee2LejUKoN2LdNpl5lG28x0MtJSAtyrzztwRF4bPhD8dX8sDiw/cBt2JyVyZG9mWOQIP/3gJ4MUzakuCSPwL17NbBowDaBnz54BV9M8lO6vZvGm3awoLGV5YRkri8rYta/64PM5HVvSt0sbTuubRd8ubejZqTU5HVvSrX0maanxFdz1YWakGqSmpNIi8N98kabR2L/qRUCPQx7nRJYd5O7TgekAeXl5sfs8LQeV7q9mwcZdzN+wk/kbdrF22x7c647M+3Vpw+mDujAspwPDurenf9c2mstEJIk09v/mRUA/M+tNXbhPBb7RyG02e+7O2m17eWPNdt5YW8yyglLcoUVaCnm5HfnB6f0Z2yeLYd3b0zJDY6tFklmjhry715rZjcA/qRtC+Zi7f9iYbTZX4bCzcNMuXl65lTfWFFNUWgHA8Jz23DSpH6f27czwnPa0SFOoizQnjf653N1fBl5u7HaaqzVb9/C3ZUX8fdknfFJWSWZ6Cqf17cyNX+7Llwd2oatObxdp1tT5moDK9tcwZ2khsxcVsG77XtJSjAn9s7nlKwM5Y3BX9amLyEFKgwSysrCMv8zfxIvLP6GyJsyInh248/whnD2sG1ltWgRdnojEIYV8nKsNhXlp5VYee3cjywvLaJWRykUjc7h0bC8Gn9gu6PJEJM4p5ONUVW2IZ5cU8fDbH7Nl135Oym7Nz84bwoUju9NOF34QkeOkkI8z+6pqmblwC396ZwPb91Rxck57bj9nFGcM6qpT4UWk3hTycaImFObpRQXc/3o+O8qr+MJJWdx38Sl84aSsg3O5iIjUl0I+YO7Oq6u2ce8/17Fhxz7G5Hbij5eNZFSvTkGXJiJJQCEfoCWbd3HXS2v4YEsp/bq04ZHL85g0qIuO3EUkZhTyAdi1r5q7X1nD7MWFdG3XgnumDGPKyJyEnvxLROKTQr4JhcPOM0sKuPuVteytrOXaCX347qR+tNaUiCLSSJQuTWTdtr3c/vxKFm/ezejcjtx1wTAGnNA26LJEJMkp5BtZKOz86Z0N3Df3I1q3SOWXU4bz1VE5Gg4pIk1CId+Ituzcz83PLGPRpt2cNaQr/3vhME0/ICJNSiHfCNydWYsKuPMfq0kx476LT+bCEd01akZEmpxCPsbK9tdw8zPLeX3Ndr5wUhb3fu1kundoGXRZItJMKeRjaFVRGf/51BK2lVXyk3MHc+UXctX3LiKBUsjHgLszc2EBP/37h3RuncHsa8czomfHoMsSEVHIR6uiOsTtf1vJc0uL+GK/ztw/dQSdWmcEXZaICKCQj0pRaQVXz1jEuu17uWlSP747qR+p6p4RkTiikG+gFYWlXP34YiqrQzz2rdF8aUCXoEsSEfkchXwDvLpqK9+btYys1i146vqx9O+qM1dFJD5FNSOWmd1rZmvNbIWZPW9mHQ557lYzW29m68zsrKgrjQPuzsNvf8x1Ty5lULd2/O2GUxXwIhLXop328DVgqLsPBz4CbgUws8HAVGAIMBl4yMxSo2wrULWhMLc9v5K7X1nLOcO7MfOacWS31dmrIhLfogp5d5/r7rWRh/OBnMj984Gn3b3K3TcC64Ex0bQVpKraEN+Z+QEzFxZww5dO4ndTR5CZntB/s0SkmYhln/xVwKzI/e7Uhf4BhZFln2Nm04BpAD179oxhObFRUR3i2ieXMO+jEn5y7mCuPq130CWJiBy3Y4a8mb0OnHCYp2539xci69wO1AJP1bcAd58OTAfIy8vz+r6+Me2prOHqGYtYsnk3v5wynItH9wi6JBGRejlmyLv76Ud73sy+BZwLTHL3AyFdBByaiDmRZQljZ3kVV/x5Ieu27eV3l4zknOHdgi5JRKTeoh1dMxn4EXCeu+8/5KkXgalm1sLMegP9gIXRtNWUSvZW8fXp88nfXs70y/MU8CKSsKLtk/890AJ4LTKN7nx3v87dPzSz2cBq6rpxbnD3UJRtNYld+6q59JEFFO2u4PGrxjCuT1bQJYmINFhUIe/ufY/y3M+Bn0ez/aZWVlHDZY8uYNPOffz5W6MV8CKS8KIdJ5809lbWcMVjC8nfXs4fLxvFF/p2DrokEZGoKeSB/dW1XDVjEauKyvj9N0YwUfPQiEiSaPYhX1kT4tuPL2bJ5t3cP3UEZw453GhREZHE1KwnKAuHnZufWc77H+/kvotP1igaEUk6zfpI/n9fXsNLK7Zy29kDuWhkzrFfICKSYJptyD/67kYeeXcj3/pCLtd8sU/Q5YiINIpmGfIvrdjKXS+tZvKQE/jJuYOJjPEXEUk6zS7kF27cxfdnL2NUz478duopulyfiCS1ZhXyG0rK+fbji8jp2JI/XZ6n6YJFJOk1m5DfU1nDNU8sJi01hcevHEPH1hlBlyQi0uiaRciHws73nl7G5p37eeibI+nRqVXQJYmINIlmEfK/nruOf60t5o7zhmg+GhFpVpI+5P++/BMeeutjLhnTk0vHxt+Vp0REGlNSh/yqojJ+OGc5eb068rPzhmiopIg0O0kb8jvLq7j2L0vo2CqDP1w6ioy0pN1VEZEjSsq5a8Jh5/uzl1NSXsWc68aT3bZF0CWJiAQiKQ9vH573MfM+KuEn5w5meE6HoMsREQlM0oX8ok27+PXcjzhneDd90SoizV5ShfyufdV8568fkNOxJXdfNExftIpIs5c0ffLhsPOD2cvYta+a567/Am0z04MuSUQkcElzJP/HeRt4a10JPzl3EEO7tw+6HBGRuBCTkDezm83Mzaxz5LGZ2QNmtt7MVpjZyFi0cySLN+3iV3PXcc6wblw6rldjNiUiklCiDnkz6wGcCWw5ZPFXgH6Rn2nAH6Jt52gy01M5tW9nfjFF/fAiIoeKxZH8b4AfAX7IsvOBJ7zOfKCDmTXaBVSHdm/PE1eNoZ364UVEPiWqkDez84Eid1/+mae6AwWHPC6MLDvcNqaZ2WIzW1xSUhJNOSIi8hnHHF1jZq8DJxzmqduB26jrqmkwd58OTAfIy8vzY6wuIiL1cMyQd/fTD7fczIYBvYHlkX7wHGCpmY0BioAeh6yeE1kmIiJNqMHdNe6+0t27uHuuu+dS1yUz0t23AS8Cl0dG2YwDytx9a2xKFhGR49VYJ0O9DJwNrAf2A1c2UjsiInIUMQv5yNH8gfsO3BCrbYuISMMkzRmvIiLyeQp5EZEkZnU9K/HBzEqAzQ18eWdgRwzLCZL2JT4ly74ky36A9uWAXu6efbgn4irko2Fmi909L+g6YkH7Ep+SZV+SZT9A+3I81F0jIpLEFPIiIkksmUJ+etAFxJD2JT4ly74ky36A9uWYkqZPXkREPi+ZjuRFROQzkirkzezOyJWolpnZXDM7MeiaGsrM7jWztZH9ed7MOgRdU0OZ2dfM7EMzC5tZwo2EMLPJZrYucqWzHwddT0OZ2WNmVmxmq4KuJVpm1sPM3jSz1ZHfrZuCrqkhzCzTzBaa2fLIfvws5m0kU3eNmbVz9z2R+98FBrv7dQGX1SBmdibwL3evNbN7ANz9loDLahAzGwSEgT8C/+XuiwMu6biZWSrwEXAGdZPwLQIucffVgRbWAGY2ASin7oI+Q4OuJxqRixB1c/elZtYWWAJckGj/LlY3hW9rdy83s3TgXeCmyMWWYiKpjuQPBHxEaz59taqE4u5z3b028nA+ddM1JyR3X+Pu64Kuo4HGAOvdfYO7VwNPU3fls4Tj7vOAXUHXEQvuvtXdl0bu7wXWcIQLE8WzyNXzyiMP0yM/Mc2tpAp5ADP7uZkVAN8E/ifoemLkKuCVoItopo77KmcSDDPLBUYACwIupUHMLNXMlgHFwGvuHtP9SLiQN7PXzWzVYX7OB3D32929B/AUcGOw1R7dsfYlss7tQC11+xO3jmdfRGLNzNoAzwLf+8wn+YTh7iF3P4W6T+tjzCymXWmNNZ98oznSlaoO4ynq5rW/oxHLicqx9sXMvgWcC0zyOP/ypB7/LolGVzmLU5E+7GeBp9z9uaDriZa7l5rZm8BkIGZfjifckfzRmFm/Qx6eD6wNqpZomdlk4EfAee6+P+h6mrFFQD8z621mGcBU6q58JgGKfGH5KLDG3e8Lup6GMrPsAyPnzKwldV/wxzS3km10zbPAAOpGcmwGrnP3hDzqMrP1QAtgZ2TR/AQeKXQh8DsgGygFlrn7WYEWVQ9mdjbwWyAVeMzdfx5sRQ1jZjOBidTNdrgduMPdHw20qAYys9OAd4CV1P1/B7jN3V8Orqr6M7PhwOPU/W6lALPd/f/FtI1kCnkREfm0pOquERGRT1PIi4gkMYW8iEgSU8iLiCQxhbyISBJTyIuIJDGFvIhIElPIi4gksf8P49VH+I9HxDQAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# 画出这个函数的曲线\n", - "x_sample = np.arange(-3, 3.1, 0.1)\n", - "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n", - "\n", - "plt.plot(x_sample, y_sample, label='real curve')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接着我们可以构建数据集,需要 x 和 y,同时是一个三次多项式,所以我们取了 $x,\\ x^2, x^3$" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# 构建数据 x 和 y\n", - "# x 是一个如下矩阵 [x, x^2, x^3]\n", - "# y 是函数的结果 [y]\n", - "\n", - "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n", - "x_train = torch.from_numpy(x_train).float() # 转换成 float tensor\n", - "\n", - "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # 转化成 float tensor " - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([61, 3])\n" - ] - } - ], - "source": [ - "print(x_train.size())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接着我们可以定义需要优化的参数,就是前面这个函数里面的 $w_i$" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "# 定义参数和模型\n", - "w = Variable(torch.randn(3, 1), requires_grad=True)\n", - "b = Variable(torch.zeros(1), requires_grad=True)\n", - "\n", - "# 将 x 和 y 转换成 Variable\n", - "x_train = Variable(x_train)\n", - "y_train = Variable(y_train)\n", - "\n", - "def multi_linear(x):\n", - " return torch.mm(x, w) + b\n", - "\n", - "def get_loss(y_, y):\n", - " return torch.mean((y_ - y) ** 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "我们可以画出没有更新之前的模型和真实的模型之间的对比" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAprklEQVR4nO3deXhUVbb38e8iBMI8S6OgwRaRQREJGLQVGhRQEURFaW2cRbr1tlM74tt4Ha4DjiiKKIheuSCDA22LAiqiImrACQEBUSGIEEBBIECG/f6xq0jAAEmqKqeq8vs8z3lODafOWZVh1a599lnbnHOIiEhyqhJ0ACIiEjtK8iIiSUxJXkQkiSnJi4gkMSV5EZEkpiQvIpLESp3kzWycma03s0XFHmtoZrPMbHlo3SD0uJnZSDNbYWZfmdlxsQheRET2rywt+fFAn70euxV4xznXCngndB/gNKBVaBkCPB1ZmCIiUh5WlouhzCwdeMM51z50/1ugu3NurZk1A+Y451qb2TOh2xP33m5/+2/cuLFLT08v3zsREamkFixYsME516Sk56pGuO+mxRL3z0DT0O1DgNXFtssOPbbfJJ+enk5WVlaEIYmIVC5m9uO+novaiVfnvxKUuUaCmQ0xsywzy8rJyYlWOCIiQuRJfl2om4bQen3o8TVAi2LbNQ899jvOuTHOuQznXEaTJiV+2xARkXKKNMlPBy4O3b4YeL3Y4xeFRtlkApsP1B8vIiLRV+o+eTObCHQHGptZNjAcuB+YbGaXAz8C54U2fxM4HVgBbAcuLW+AeXl5ZGdns2PHjvLuQiKQlpZG8+bNSU1NDToUESmHUid559xf9vFUzxK2dcDV5Q2quOzsbOrUqUN6ejpmFo1dSik559i4cSPZ2dm0bNky6HBEpBzi/orXHTt20KhRIyX4AJgZjRo10rcokQQW90keUIIPkH72IoktIZK8iEgyu+sumDs3NvtWki+FkSNH0qZNGy688EKmT5/O/fffD8Brr73G4sWLd283fvx4fvrpp933r7jiij2eFxHZ27JlMHw4vP9+bPYf6RWvlcJTTz3F7Nmzad68OQD9+vUDfJLv27cvbdu2BXySb9++PQcffDAAzz33XDABF5Ofn0/Vqvo1i8Srp56C1FS48srY7F8t+QMYOnQoK1eu5LTTTuPRRx9l/PjxXHPNNcybN4/p06dz0003ceyxx/LAAw+QlZXFhRdeyLHHHktubi7du3ffXaahdu3aDBs2jA4dOpCZmcm6desA+O6778jMzOToo4/mjjvuoHbt2iXG8eKLL3LMMcfQoUMHBg8eDMAll1zC1KlTd28Tfu2cOXM46aST6NevH23btuXWW29l1KhRu7e78847eeihhwAYMWIEnTt35phjjmH48OHR/wGKyD5t3QrPPw8DB8If/hCbYyRWE++66+CLL6K7z2OPhcce2+fTo0eP5q233uK9996jcePGjB8/HoATTjiBfv360bdvX84991wAZsyYwUMPPURGRsbv9rNt2zYyMzO59957ufnmm3n22We54447uPbaa7n22mv5y1/+wujRo0uM4ZtvvuGee+5h3rx5NG7cmE2bNh3wbS1cuJBFixbRsmVLPv/8c6677jquvtqPap08eTJvv/02M2fOZPny5Xz66ac45+jXrx9z587l5JNPPuD+RSRyEybAli1wdVQGnJdMLfkKUq1aNfr27QtAp06d+OGHHwD4+OOPGThwIAAXXHBBia999913GThwII0bNwagYcOGBzxely5ddo9t79ixI+vXr+enn37iyy+/pEGDBrRo0YKZM2cyc+ZMOnbsyHHHHcfSpUtZvnx5pG9VRErBOXjySejYEbp2jd1xEqslv58Wd7xLTU3dPRwxJSWF/Pz8iPdZtWpVCgsLASgsLGTXrl27n6tVq9Ye2w4cOJCpU6fy888/c/755wP+YqfbbruNq666KuJYRKRs5s6FRYtg7FiI5UhlteQjUKdOHX777bd93i+NzMxMpk2bBsCkSZNK3KZHjx5MmTKFjRs3AuzurklPT2fBggUATJ8+nby8vH0e5/zzz2fSpElMnTp19zeH3r17M27cOLZu3QrAmjVrWL9+/T73ISLR8+ST0KABDBoU2+MoyUdg0KBBjBgxgo4dO/Ldd99xySWXMHTo0N0nXkvjscce45FHHuGYY45hxYoV1KtX73fbtGvXjmHDhtGtWzc6dOjADTfcAMCVV17J+++/T4cOHfj4449/13rfex+//fYbhxxyCM2aNQOgV69eXHDBBXTt2pWjjz6ac889t8wfUiJSdtnZ8OqrcPnlULNmbI9VppmhYi0jI8PtPWnIkiVLaNOmTUARxd727dupUaMGZsakSZOYOHEir7/++oFfWIGS/XcgUtH+9S+45x5YsQIOPzzy/ZnZAufc70d8kGh98klowYIFXHPNNTjnqF+/PuPGjQs6JBGJoZ074Zln4IwzopPgD0RJPmAnnXQSX375ZdBhiEgFmTYN1q+P7bDJ4tQnLyJSgZ58Eo44Anr1qpjjKcmLiFSQzz6Djz/2rfgqFZR9leRFRCrII49A3bpw2WUVd0wleRGRCrBqFUyZ4guR1a1bccdVkq8A6enpbNiwIegwRCRAI0f69T/+UbHHVZIvA+fc7jICikNESmvLFnj2WV9t8tBDK/bYSvIH8MMPP9C6dWsuuugi2rdvz+rVq/dZnvess86iU6dOtGvXjjFjxhxw32+99RbHHXccHTp0oGdPPx968TLAAO3bt+eHH374XRx33303N9100+7twiWQAV566SW6dOnCsccey1VXXUVBQUG0fhwiUg5jx/pEf+ONFX/sqIyTN7PrgSsAB3wNXAo0AyYBjYAFwGDn3K597qQUAqg0DMDy5ct54YUXyMzM3G953nHjxtGwYUNyc3Pp3Lkz55xzDo0aNSpxnzk5OVx55ZXMnTuXli1blqp8cPE4cnJy6Nq1KyNGjADg5ZdfZtiwYSxZsoSXX36Zjz76iNTUVP7+978zYcIELrroojL+ZEQkGvLz4fHH4aSToIQq5DEXcZI3s0OAfwBtnXO5ZjYZGAScDjzqnJtkZqOBy4GnIz1eEA477DAyMzMB9ijPC7B161aWL1/OySefzMiRI3n11VcBWL16NcuXL99nkp8/fz4nn3zy7nLApSkfXDyOJk2acPjhhzN//nxatWrF0qVLOfHEExk1ahQLFiygc+fOAOTm5nLQQQdF9gMQkXJ75RX48Uef6IMQrSteqwI1zCwPqAmsBXoA4QLpLwB3EmGSD6rScPHCX/sqzztnzhxmz57Nxx9/TM2aNenevTs7duwo87GKlw8G9tjH3gXIBg0axOTJkznqqKMYMGAAZoZzjosvvpj77ruvzMcWkehyDh5+2F/8FJpOosJF3CfvnFsDPASswif3zfjumV+dc+Gi6dnAISW93syGmFmWmWXl5OREGk7M7as87+bNm2nQoAE1a9Zk6dKlzJ8/f7/7yczMZO7cuXz//ffAnuWDFy5cCPjZncLPl2TAgAG8/vrrTJw4kUGheqU9e/Zk6tSpu0sGb9q0iR9//DGyNy0i5TJvHnz6KVx/PaSkBBNDNLprGgD9gZbAr8AUoE9pX++cGwOMAV+FMtJ4Yq1Xr14sWbKErqGpXGrXrs1LL71Enz59GD16NG3atKF169a7u1X2pUmTJowZM4azzz6bwsJCDjroIGbNmsU555zDiy++SLt27Tj++OM58sgj97mPBg0a0KZNGxYvXkyXLl0AaNu2Lffccw+9evWisLCQ1NRURo0axWGHHRa9H4KIlMojj/ia8RdfHFwMEZcaNrOBQB/n3OWh+xcBXYGBwB+cc/lm1hW40znXe3/7qoylhhOBfgciZbdiBRx5JNx6K/zP/8T2WPsrNRyNIZSrgEwzq2l+fruewGLgPeDc0DYXA/FVJF1EJIYefBCqV4drrw02jmj0yX8CTAUW4odPVsF3v9wC3GBmK/DDKMdGeiwRkUSwZg2MH+9r1DRtGmwsURld45wbDgzf6+GVQJco7X/3JNhSseJp5jCRRPHww1BYCMWuVwxM3F/xmpaWxsaNG5VsAuCcY+PGjaSlpQUdikjC2LDBz/x04YWQnh50NAkwM1Tz5s3Jzs4mEYZXJqO0tDSaN28edBgiCWPkSMjN9Sdc40HcJ/nU1NTdV4WKiMSzLVvgiSdgwACIlwFpcd9dIyKSKEaPhl9/hdtuCzqSIkryIiJRkJvrL37q1SuYQmT7oiQvIhIFzz8P69bFVyselORFRCKWl+cvfuraFbp1CzqaPcX9iVcRkXj34ou+nPCTT0K8XdKjlryISAR27YK77/b98GecEXQ0v6eWvIhIBJ5/3rfin346/lrxoJa8iEi57dwJ99wDmZnQp9QF1iuWWvIiIuX03HOQne1b8/HYige15EVEyiU319eJP+kk6Nkz6Gj2TS15EZFyGDMGfvoJJkyI31Y8qCUvIlJm27fDfffBn/8M3bsHHc3+qSUvIlJGTz/tr26dMiXoSA5MLXkRkTLYuhUeeABOPdX3x8c7JXkRkTJ49FHIyYG77go6ktJRkhcRKaX1632NmgED/Nj4RKAkLyJSSnff7YdO3ndf0JGUnpK8iEgprFjhJwW54gpo3TroaEovKknezOqb2VQzW2pmS8ysq5k1NLNZZrY8tG4QjWOJiARh2DCoVg2GDw86krKJVkv+ceAt59xRQAdgCXAr8I5zrhXwTui+iEjC+ewzmDwZ/vlPaNYs6GjKJuIkb2b1gJOBsQDOuV3OuV+B/sALoc1eAM6K9FgiIhXNObj5ZjjoIJ/kE000WvItgRzgeTP73MyeM7NaQFPn3NrQNj8DTUt6sZkNMbMsM8vKycmJQjgiItEzYwbMmQP/+hfUqRN0NGUXjSRfFTgOeNo51xHYxl5dM845B7iSXuycG+Ocy3DOZTRp0iQK4YiIREdBAdxyCxxxBAwZEnQ05RONJJ8NZDvnPgndn4pP+uvMrBlAaL0+CscSEakw48fDokW+2mRqatDRlE/ESd459zOw2szCg4p6AouB6cDFoccuBl6P9FgiIhVl82a4/XY48UQ499ygoym/aBUo+y9ggplVA1YCl+I/QCab2eXAj8B5UTqWiEjM3XWXL18wY0Z8lxI+kKgkeefcF0BGCU/FcSl9EZGSLV0KI0fC5ZfDcccFHU1kdMWriEgxzsF110GtWnDvvUFHEznVkxcRKeaNN+Dtt321yYMOCjqayKklLyISsnMnXH89tGkDV18ddDTRoZa8iEjIY4/Bd9/5lnyiDpncm1ryIiLA2rVwzz3Qrx/06hV0NNGjJC8iAtxwA+zaBQ8/HHQk0aUkLyKV3ltvwaRJvpzwEUcEHU10KcmLSKW2fTv8/e9+IpBbbgk6mujTiVcRqdTuugu+/95XmqxePehook8teRGptL7+2vfBX3YZdOsWdDSxoSQvIpVSYaEvH1y/Pjz4YNDRxI66a0SkUhozBubPhxdfhEaNgo4mdtSSF5FKZ+1auPVW6NkT/vrXoKOJLSV5EalUnPOjaXbsgKefTuwywqWh7hoRqVT+7//gtddgxAho1SroaGJPLXkRqTR++gmuuQZOOMEXIqsMlORFpFJwDq680leaHD8eUlKCjqhiqLtGRCqF8ePhzTfh8ccrRzdNmFryIpL0Vq/2sz117+67ayoTJXkRSWrO+blaCwpg3DioUsmynrprRCSpjR4Ns2b54ZItWwYdTcWL2meamaWY2edm9kbofksz+8TMVpjZy2ZWLVrHEhEpjUWLfJ343r3hqquCjiYY0fzici2wpNj9B4BHnXNHAL8Al0fxWCIi+5WbC4MGQb168MILyX/R075EJcmbWXPgDOC50H0DegBTQ5u8AJwVjWOJiJTGDTfAN9/42jRNmwYdTXCi1ZJ/DLgZKAzdbwT86pzLD93PBg4p6YVmNsTMsswsKycnJ0rhiEhlNm2a74u/6abkmq+1PCJO8mbWF1jvnFtQntc758Y45zKccxlNmjSJNBwRqeRWrYIrroDOnf3E3JVdNEbXnAj0M7PTgTSgLvA4UN/MqoZa882BNVE4lojIPuXnwwUX+OGSEydCNQ33iLwl75y7zTnX3DmXDgwC3nXOXQi8B5wb2uxi4PVIjyUisj933gkffeSHS/7xj0FHEx9ieVnALcANZrYC30c/NobHEpFKbvp0uPdeuPRSuPDCoKOJH1G9GMo5NweYE7q9EugSzf2LiJTk229h8GDo1AmeeiroaOJLJbvAV0SSzW+/wdln+/73V16BtLSgI4ovKmsgIgnLObjsMli61JcuOPTQoCOKP0ryIpKwHnoIpk71szz16BF0NPFJ3TUikpDeecdPxn3eeXDjjUFHE7+U5EUk4SxdCueeC23awNixlbcuTWkoyYtIQsnJgTPO8Cda33gDatcOOqL4pj55EUkYubnQv7+fkPv99yE9PeiI4p+SvIgkhMJCuOQSmD8fpkyBLroKp1SU5EUkIdxxB0ye7EfSnHNO0NEkDvXJi0jce+45uO8+P7uTRtKUjZK8iMS1V17xyb13b3jiCY2kKSsleRGJWzNn+in8jj/eTwSSmhp0RIlHSV5E4tJHH8FZZ0HbtvCf/0CtWkFHlJiU5EUk7nzxhR8L37w5vP02NGgQdESJS0leROLKsmV+Xta6dWH27Mo9CXc0KMmLSNxYtqyo0JiqSkaHkryIxIXFi6FbN9i1y7fgW7cOOqLkoCQvIoH76ivo3t3fnjMHjjkmyGiSi5K8iARq4UL48599wbH33/ejaSR6lORFJDCffOL74OvUgblz4cgjg44o+USc5M2shZm9Z2aLzewbM7s29HhDM5tlZstDaw2CEpHd3n4bTjkFGjf2Cf7ww4OOKDlFoyWfD9zonGsLZAJXm1lb4FbgHedcK+Cd0H0REcaN8+PgjzgCPvhAo2hiKeIk75xb65xbGLr9G7AEOAToD7wQ2uwF4KxIjyUiic05uPNOuPxy34qfOxeaNQs6quQW1VLDZpYOdAQ+AZo659aGnvoZ0CUNIpVYXp4vNPb883DppfDMM6pFUxGiduLVzGoD04DrnHNbij/nnHOA28frhphZlpll5eTkRCscEYkjv/wCffv6BD98uJ+XVQm+YkQlyZtZKj7BT3DOvRJ6eJ2ZNQs93wxYX9JrnXNjnHMZzrmMJk2aRCMcEYkjixZB587w3nu+Lvydd6pccEWKxugaA8YCS5xzjxR7ajpwcej2xcDrkR5LRBLLlCmQmQnbtvmLnC6/POiIKp9otORPBAYDPczsi9ByOnA/cKqZLQdOCd0XkUqgoABuvRXOO89fvbpgAZxwQtBRVU4Rn3h1zn0I7OvLV89I9y8iiWXdOhg82BcYGzoUHn/cX80qwdBE3iISNW++6UfObNni+9/VPRM8lTUQkYjt2AH/+Ie/wKlpU8jKUoKPF0ryIhKRRYugSxc/yfa118Knn0K7dkFHJWFK8iJSLnl5cN99kJHh++FnzIDHHoO0tKAjk+LUJy8iZfbpp3Dllb4O/LnnwpNPapq+eKWWvIiU2tatcP310LUrbNwIr73mx8IrwccvJXkROSDnYOpUaN/ed8kMHeqn6+vfP+jI5EDUXSMi+5WV5VvvH34IRx/t1yeeGHRUUlpqyYtIibKz4aKLfN2ZZctgzBj4/HMl+ESjlryI7GH9ehgxAkaNgsJCX57gttugbt2gI5PyUJIXEQB+/tkn96efhp074YIL4O67IT096MgkEkryIpXc6tXw6KMwerRP7n/9Kwwbpkm1k4WSvEgl5BzMmwcjR8K0af6xcHJv1SrY2CS6lORFKpEdO/xQyMcf96Nm6tf3I2euvlrdMslKSV4kyTnnr1B94QWYOBF+/RWOOgqeesqPnqlVK+gIJZaU5EWS1A8/wKRJPrkvXQo1asCAAXDJJdCzJ1TRAOpKQUleJEk4569CfeUVePVVP6Yd4E9/8rXdBw7UMMjKSEleJIFt3Qpz58Ls2fDGG7B8uX+8a1c/HPLss+Hww4ONUYKlJC+SQLZt8ydM33/fJ/aPP4b8fKheHbp1gxtu8PVkmjULOlKJF0ryInEqL8+XE8jKgvnz/fL1136SbDPo1An++U845RQ/SXaNGkFHLPFISV4kYLm5sHIlfPcdLFniE/nXX/vbeXl+m3r14Pjj/Tj2zEx/u2HDYOOWxBDzJG9mfYDHgRTgOefc/bE+pkg8cM73mW/c6EsG/PSTX9as8esffoAVK/zt4lq08NUeTzvNrzt29EMeNRpGyiOmSd7MUoBRwKlANvCZmU13zi2O5XGl4mzb5hPYunWwaZMfg1182boVtm8vWnJz/QU5+fm+lZqX528XFPj9Obfn/lNSipYqVaBqVb+kphatw0u1anuuS7odfn3xpUqVPRczH0dhYdG6sNDHunMn7Nrll507/fvfts2/z/Dyyy/+Z7FpU1FLvLiqVX2f+WGHwamnwhFHwB//6Jcjj/QXKIlES6xb8l2AFc65lQBmNgnoD0Q3yWdnw0cf+SZQixb+P6iqeqIiVVAAP/4I33/vW53Fl7VrfWLfunXfr69ZE+rU8esaNYrWder8PkGHkysUrZ3zMRQU+CRbUOA/EMJLXp7/wNiypegDY9euktfhJVpSUhzVUh21ahRSu0ZB0TqtgLZN82l0ZB4N6+TTqG4eDevm84eGuzi48S4ObpJH4/r5VEmxok+V8CdYSgqsrQob9voEq1bNn1mtXl3NeSmzWGfCQ4DVxe5nA8dH/SgffOBL5oVVqeITffPmft2sGRx8cNHtgw4qWjTrMM7BqlWwYAF8843vC168GL791ifRsJQU/xl62GHQpQv84Q9+2rfw0rixb4XWr+/7kKtVi0JgeXlFXwH2/koQXu9r2bGjaL1jB257LoU788jPzSNvRwH5O/LJ35GP25VH4a78oiWvgCq7dlCFAqpQiOGoQiHV2EU1dpFKHikFhVAA7AB+ifB9lkXVqv5vNi1tz0/OGjX8pau1a/ulTp2i2/Xq7bnUrw8NGvhO/bp1iz5VJSkF3tw1syHAEIBDDz20fDvp39+fqVq92rfqV68uur18uR9IvGlTya+tU8cn+0aN/NKwYdG6QYM9s1b9+v6fIvwPVL16Qv6DbNrkv/h89plfsrJgw4ai59PToU0bP2qjTRvfjZCe7j8zq1bFN6P3lVhX5cLSvZLw3kl674S9v8fD/ThlVb26T3zhZJiWhlWvTkpaGilpaVSvHUqU4RZyuLVcrVrR7b37e8J9ROFWdvG+pOJ9SuGvJcUX8B9aey/Fv6aEv6oU78cKfx3ZuXPPJfwBVvxnt3Vr0der8JKbu/+fU0qK/7tu2BCaNNlzOegg/0lefKlXLyH/5iszc3t3gkZz52ZdgTudc71D928DcM7dV9L2GRkZLisrKzbB7NjhO4/XrvWzIuTk+HV42bjRL5s2+fWWLQfeZ9WqRQm/Zs2ipVYtn1iqVy9qdYWTyd6dxeGksXeyCCeH4gkj/Lsqvg4niXCiKCws6s8IJY1NW6oyd2Vz5qw8lDmrWvJVTjMcVUixAtrVyyaj/ndk1PmWTjWX0K7acmrlby5KJOF18aW8iTclZc+fT/GfWbglWtJjxbct3not3ootvqSlqVsjLD/f/y1v3ly0/PrrnicOwn/zGzYU/W9s2FDy7zktDQ45xH/ih9fNm8Ohh/qveOnpOqkQADNb4JzLKPG5GCf5qsAyoCewBvgMuMA5901J28c0yZdVXt6e/xTh9ZYt8Ntvey57tz63bft9YszNLeogzs+PaegO+Jqj+Tdn8gZ9+YTjcVQhjVxOTJlP99R5dEv7hE41l1CzhitqvVartucHUvh28e6BvW8XT6wlJd/iyTk1NabvW6KosNB/EKxb5xtH4WXtWj88KDvbL2vW+L/r4urW9Qn/8MOLziiHl/R0nS+LgcCSfOjgpwOP4YdQjnPO3buvbeMqycdSuK+5+NCS8Dq8FP9KHx7mUfyrf3gdavkXUoUPPq3OlP/U4N8zq7MqOwWAjE6F9D0Dep5ahc6dfc4WiZrCQt/yX7XKn6X/4Yei9cqVfil+Yic11Sf71q2LljZtoF07FdaJQKBJviwqTZKPoq++ggkTfAnZ1at9g/mUU+DMM+GMM/z5ZpHAFBb61v933/mLApYt82f0v/3W3y8+5KlFC5/s27XzFwgce6z/AIj4DH7y21+S1/emBPTLL7587NixsGiR//bbuzfcf78/B6364BI3qlTxffeHHAInn7znc/n5vsW/eLEf1hVe3nvPnwsC3/Jv08Yn/I4dISPD365du4LfSOJSSz6BLFjgJ3qYONF38WdmwuDBvoRskyZBRycSJQUFflTcl1/CF18Urdeu9c+b+cSfkQGdO/t/hA4dKvU5H3XXJLD8fJgyxU+0/NlnvpX+17/C3/7m/65FKo21a31LZ8ECP+73s8/8iWHwJ/4zMnzCP+EEX0S/ErV8lOQT0M6dvkvmgQf8uaujjvLzcA4e7Icqi1R6zvkTUeESnfPn+w+A8Gif1q3hpJP80q2bH/GTpJTkE8jWrTBmDDz8sC9c1bkz3H479Ounod8iB7Rzp0/0H3zgl48+8kOfwQ/p7NHDL3/+s7+4K0koySeAvDw/Rdt//7f/Btqjh0/uPXroAkORciss9KMT5szxJ3TnzClK+m3b+hELvXv7k8IJXJBfST6OOefn5Lz9dj+67KST/CiZE04IOjKRJFRQ4E/ivvsuzJrlS57s3On79E8+Gfr0gb59oVWroCMtEyX5ODVvHtx4o+9KbNvWJ/e+fdVyF6kw27f7RP/WW/D227B0qX/8yCP9xSZ9+8KJJ8b9yB0l+TizYQPccguMG+cvVrrrLrj4Yl3tLRK477+H//wH/v1v37Wza5evxXPmmX5W9F69fJmOOKMkHycKC+H5532C37wZrr8e/vUvXdchEpd++8136Uyf7pdffvH99qed5hP+mWfGTSkGJfk4sGgRDB3qT/b/6U/w9NPQvn3QUYlIqeTl+W6dV16B117zQ9+qV4fTT4dBg3wNkQAvNd9fktegvBgrKIAHH4ROnXx339ix8P77SvAiCSU1FXr2hFGj/Nj8jz6Cq67yJ9TOP9/X3h80yHfzRHMKsihQSz6GVq70fe0ffggDBsAzz1Sqi/BEkl9BgR+P//LLMHWqP+HWpIlP+IMH+6twK2AkhVryFcw5P+a9QwdfJfLFF2HaNCV4kaSTkgLdu/v+159+8n333bv7Kxq7dPHD5h54wNfiD4iSfJT98gucdRZceaX/HX/9tf9A17BIkSSXmupPxk6e7JP6s8/6iY9vvdXPntW/v/8QiPGkQXtTko+ihQt93/uMGb6g2KxZflY0Ealk6teHK67wXTnffgv//Cd8+qlP9C1awLBhfnKVCqAkHwXO+W9nJ5zgP6Q/+ACuu061ZkQEf2HV/ff72bNef90XpLr/fl9L58wz4c03yz9vcikoDUVo+3a45BJ/or1bN9+aP/74oKMSkbiTmuorDU6f7i+6uu02Xy75jDPgiCP8ybsYUJKPwKpV0LUr/O//wvDh/gO5ceOgoxKRuHfooXDPPT6JvPyyn+B8+/aYHEoX0pdTVpb/prV9u78K+rTTgo5IRBJOtWpw3nl+idFwdrXky+GVV3zBuurVfZExJXgRiViMhuBFlOTNbISZLTWzr8zsVTOrX+y528xshZl9a2a9I440Djjnr1495xw/Bv6TT/zE8iIi8SrSlvwsoL1z7hhgGXAbgJm1BQYB7YA+wFNmlhLhsQKVn+9Prt5yi/9m9e670LRp0FGJiOxfREneOTfTORce2T8faB663R+Y5Jzb6Zz7HlgBdInkWEHaudNfpfzss35yj4kTE3oSGRGpRKLZJ38ZMCN0+xBgdbHnskOP/Y6ZDTGzLDPLysnJiWI40bF9u79+Ydo0f4HTvfdq/LuIJI4Djq4xs9lASTPeDnPOvR7aZhiQD0woawDOuTHAGPAFysr6+ljavNlPDDNvnq8eedllQUckIlI2B0zyzrlT9ve8mV0C9AV6uqKSlmuAFsU2ax56LGHk5PjpHr/+GiZNgoEDg45IRKTsIh1d0we4GejnnCs+kn86MMjMqptZS6AV8Gkkx6pI69b5q1cXL/ZXISvBi0iiivRiqCeB6sAs82M85zvnhjrnvjGzycBifDfO1c652BVniKING+CUU3ztoLfe8sleRCRRRZTknXNH7Oe5e4F7I9l/Rfv1Vz9P74oV/ipWJXgRSXQqaxCyZYvvg//mG99F06NH0BGJiEROSR7Yts0XgluwwM/g1adP0BGJiERHpU/yO3b46p/z5vlRNP37Bx2RiEj0VOokX1joJ9p+911fylmjaEQk2VTqazdvuslPxzhihJ+HVUQk2VTaJP/YY/DII/Bf/wU33hh0NCIisVEpk/yUKXDDDXD22b4eTYzKOIuIBK7SJfkPPvBdMyecAC+9BCkJXQBZRGT/KlWSX7bMj6RJT/dj4VUuWESSXaVJ8ps3++GRVavCjBnQqFHQEYmIxF6lGEJZUAAXXujLFcyeDS1bBh2RiEjFqBRJ/v/9P1+L5qmnVI9GRCqXpO+ueflluO8+GDIEhg4NOhoRkYqV1En+88/h0kvhxBPhiSc0VFJEKp+kTfI5OXDWWf4E67RpUK1a0BGJiFS8pOyTLyz0Y+HXrYMPP4SmTYOOSEQkGEmZ5B98EN5+259ozcgIOhoRkeAkXXfNhx/CHXfAeefpRKuISFIl+Q0bYNAgf0Xrs8/qRKuISNJ01xQWwkUX+ROu8+dD3bpBRyQiErykSfIjRvhyBaNGQceOQUcjIhIfotJdY2Y3mpkzs8ah+2ZmI81shZl9ZWbHReM4+/LRRzBsmJ/Z6W9/i+WRREQSS8RJ3sxaAL2AVcUePg1oFVqGAE9Hepz9qVkTTjlF/fAiInuLRkv+UeBmwBV7rD/wovPmA/XNrFkUjlWijh3hrbegXr1YHUFEJDFFlOTNrD+wxjn35V5PHQKsLnY/O/RYSfsYYmZZZpaVk5MTSTgiIrKXA554NbPZwB9KeGoYcDu+q6bcnHNjgDEAGRkZ7gCbi4hIGRwwyTvnTinpcTM7GmgJfGm+I7w5sNDMugBrgBbFNm8eekxERCpQubtrnHNfO+cOcs6lO+fS8V0yxznnfgamAxeFRtlkApudc2ujE7KIiJRWrMbJvwmcDqwAtgOXxug4IiKyH1FL8qHWfPi2A66O1r5FRKR8kqp2jYiI7ElJXkQkiZnvWYkPZpYD/FjOlzcGNkQxnCDpvcSnZHkvyfI+QO8l7DDnXJOSnoirJB8JM8tyziXFFCF6L/EpWd5LsrwP0HspDXXXiIgkMSV5EZEklkxJfkzQAUSR3kt8Spb3kizvA/ReDihp+uRFROT3kqklLyIie0mqJG9md4dmovrCzGaa2cFBx1ReZjbCzJaG3s+rZlY/6JjKy8wGmtk3ZlZoZgk3EsLM+pjZt6GZzm4NOp7yMrNxZrbezBYFHUukzKyFmb1nZotDf1vXBh1TeZhZmpl9amZfht7Hf0f9GMnUXWNmdZ1zW0K3/wG0dc4NDTiscjGzXsC7zrl8M3sAwDl3S8BhlYuZtQEKgWeAfzrnsgIOqdTMLAVYBpyKL8L3GfAX59ziQAMrBzM7GdiKn9CnfdDxRCI0CVEz59xCM6sDLADOSrTfi/kSvrWcc1vNLBX4ELg2NNlSVCRVSz6c4ENqsedsVQnFOTfTOZcfujsfX645ITnnljjnvg06jnLqAqxwzq10zu0CJuFnPks4zrm5wKag44gG59xa59zC0O3fgCXsY2KieBaaPW9r6G5qaIlq3kqqJA9gZvea2WrgQuBfQccTJZcBM4IOopIq9SxnEgwzSwc6Ap8EHEq5mFmKmX0BrAdmOeei+j4SLsmb2WwzW1TC0h/AOTfMOdcCmABcE2y0+3eg9xLaZhiQj38/cas070Uk2sysNjANuG6vb/IJwzlX4Jw7Fv9tvYuZRbUrLVb15GNmXzNVlWACvq798BiGE5EDvRczuwToC/R0cX7ypAy/l0SjWc7iVKgPexowwTn3StDxRMo596uZvQf0AaJ2cjzhWvL7Y2atit3tDywNKpZImVkf4Gagn3Nue9DxVGKfAa3MrKWZVQMG4Wc+kwCFTliOBZY45x4JOp7yMrMm4ZFzZlYDf4I/qnkr2UbXTANa40dy/AgMdc4lZKvLzFYA1YGNoYfmJ/BIoQHAE0AT4FfgC+dc70CDKgMzOx14DEgBxjnn7g02ovIxs4lAd3y1w3XAcOfc2ECDKicz+xPwAfA1/v8d4Hbn3JvBRVV2ZnYM8AL+b6sKMNk5d1dUj5FMSV5ERPaUVN01IiKyJyV5EZEkpiQvIpLElORFRJKYkryISBJTkhcRSWJK8iIiSUxJXkQkif1/Z7y6+/bBNA8AAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# 画出更新之前的模型\n", - "y_pred = multi_linear(x_train)\n", - "\n", - "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", - "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可以发现,这两条曲线之间存在差异,我们计算一下他们之间的误差" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(1144.2655, grad_fn=)\n" - ] - } - ], - "source": [ - "# 计算误差,这里的误差和一元的线性模型的误差是相同的,前面已经定义过了 get_loss\n", - "loss = get_loss(y_pred, y_train)\n", - "print(loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# 自动求导\n", - "loss.backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ -94.7455],\n", - " [-139.1247],\n", - " [-629.8584]])\n", - "tensor([-25.7413])\n" - ] - } - ], - "source": [ - "# 查看一下 w 和 b 的梯度\n", - "print(w.grad)\n", - "print(b.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "# 更新一下参数\n", - "w.data = w.data - 0.001 * w.grad.data\n", - "b.data = b.data - 0.001 * b.grad.data" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAApLElEQVR4nO3de5yV4/7/8denaWo6n0tEE0IHElMmNqVI6FuiCJuctbEdt2PsfB2+DiFySpTy0y4dHNo2SUibClMOpaJEGqVGqVRTzeH6/XGt1UyZqZlZa+Zea837+Xhcj3W+789qNZ91reu+7s9lzjlERCQxVQk6ABERKT9K8iIiCUxJXkQkgSnJi4gkMCV5EZEEpiQvIpLASpzkzWyMma0zs0WF7mtoZu+b2bLQZYPQ/WZmI8xsuZl9Y2bHlEfwIiKyd6XpyY8Feu1x3x3AB8651sAHodsApwOtQ+0q4PnIwhQRkbKw0pwMZWapwNvOufah298B3Zxza8ysOTDLOXe4mb0Quj5hz+ftbfuNGzd2qampZXsnIiKV1Pz5839zzjUp6rGqEW67WaHE/SvQLHT9AGBVoedlhu7ba5JPTU0lIyMjwpBERCoXM1tZ3GNRO/Dq/E+CUtdIMLOrzCzDzDKysrKiFY6IiBB5kl8bGqYhdLkudP8vwIGFntcidN+fOOdGOefSnHNpTZoU+WtDRETKKNIkPw0YFLo+CHir0P0Xh2bZpAOb9jUeLyIi0VfiMXkzmwB0AxqbWSYwFHgYmGRmlwMrgXNDT38HOANYDmwDLi1rgDk5OWRmZrJ9+/aybkIikJKSQosWLUhOTg46FBEpgxIneefc+cU81KOI5zrg2rIGVVhmZiZ16tQhNTUVM4vGJqWEnHOsX7+ezMxMWrVqFXQ4IlIGMX/G6/bt22nUqJESfADMjEaNGulXlEgci/kkDyjBB0j/9iLxLS6SvIhIIrvvPpg9u3y2rSRfAiNGjKBNmzZceOGFTJs2jYcffhiAN998k8WLF+963tixY1m9evWu21dcccVuj4uI7On772HoUPj44/LZfqRnvFYKzz33HDNnzqRFixYA9OnTB/BJvnfv3rRt2xbwSb59+/bsv//+ALz00kvBBFxIbm4uVavqYxaJVc89B8nJcOWV5bN99eT3YfDgwaxYsYLTTz+d4cOHM3bsWK677jrmzJnDtGnTuPXWWzn66KN55JFHyMjI4MILL+Too48mOzubbt267SrTULt2bYYMGUKHDh1IT09n7dq1APzwww+kp6dz5JFHcvfdd1O7du0i43jllVc46qij6NChAxdddBEAl1xyCVOmTNn1nPBrZ82axYknnkifPn1o27Ytd9xxB88+++yu591777089thjAAwbNoxOnTpx1FFHMXTo0Oj/A4pIsbZsgZdfhgEDYL/9ymcf8dXFu/FG+Oqr6G7z6KPhySeLfXjkyJFMnz6djz76iMaNGzN27FgAjj/+ePr06UPv3r3p378/AO+++y6PPfYYaWlpf9rO1q1bSU9P58EHH+S2227jxRdf5O677+aGG27ghhtu4Pzzz2fkyJFFxvDtt9/ywAMPMGfOHBo3bsyGDRv2+bYWLFjAokWLaNWqFV9++SU33ngj117rZ7VOmjSJ9957jxkzZrBs2TI+//xznHP06dOH2bNnc9JJJ+1z+yISufHjYfNmuDYqE86Lpp58BalWrRq9e/cG4Nhjj+Wnn34CYO7cuQwYMACACy64oMjXfvjhhwwYMIDGjRsD0LBhw33ur3Pnzrvmtnfs2JF169axevVqvv76axo0aMCBBx7IjBkzmDFjBh07duSYY45h6dKlLFu2LNK3KiIl4Bw88wx07AhdupTffuKrJ7+XHnesS05O3jUdMSkpidzc3Ii3WbVqVfLz8wHIz89n586dux6rVavWbs8dMGAAU6ZM4ddff+W8884D/MlOd955J1dffXXEsYhI6cyeDYsWwejRUJ4zldWTj0CdOnX4448/ir1dEunp6UydOhWAiRMnFvmc7t27M3nyZNavXw+wa7gmNTWV+fPnAzBt2jRycnKK3c95553HxIkTmTJlyq5fDqeddhpjxoxhy5YtAPzyyy+sW7eu2G2ISPQ88ww0aAADB5bvfpTkIzBw4ECGDRtGx44d+eGHH7jkkksYPHjwrgOvJfHkk0/yxBNPcNRRR7F8+XLq1av3p+e0a9eOIUOG0LVrVzp06MDNN98MwJVXXsnHH39Mhw4dmDt37p9673tu448//uCAAw6gefPmAPTs2ZMLLriALl26cOSRR9K/f/9Sf0mJSOllZsIbb8Dll0PNmuW7r1KtDFXe0tLS3J6LhixZsoQ2bdoEFFH527ZtGzVq1MDMmDhxIhMmTOCtt97a9wsrUKJ/BiIV7Z//hAcegOXL4eCDI9+emc13zv15xgfxNiafgObPn891112Hc4769eszZsyYoEMSkXK0Ywe88AKceWZ0Evy+KMkH7MQTT+Trr78OOgwRqSBTp8K6deU7bbIwjcmLiFSgZ56BQw+Fnj0rZn9K8iIiFeSLL2DuXN+Lr1JB2VdJXkSkgjzxBNStC5ddVnH7VJIXEakAP/8Mkyf7QmR161bcfpXkK0Bqaiq//fZb0GGISIBGjPCX119fsftVki8F59yuMgKKQ0RKavNmePFFX23yoIMqdt9K8vvw008/cfjhh3PxxRfTvn17Vq1aVWx53rPOOotjjz2Wdu3aMWrUqH1ue/r06RxzzDF06NCBHj38euiFywADtG/fnp9++ulPcdx///3ceuutu54XLoEM8Oqrr9K5c2eOPvporr76avLy8qL1zyEiZTB6tE/0t9xS8fuOyjx5M7sJuAJwwELgUqA5MBFoBMwHLnLO7Sx2IyUQQKVhAJYtW8a4ceNIT0/fa3neMWPG0LBhQ7Kzs+nUqRPnnHMOjRo1KnKbWVlZXHnllcyePZtWrVqVqHxw4TiysrLo0qULw4YNA+C1115jyJAhLFmyhNdee41PP/2U5ORkrrnmGsaPH8/FF19cyn8ZEYmG3Fx46ik48UQoogp5uYs4yZvZAcD1QFvnXLaZTQIGAmcAw51zE81sJHA58Hyk+wtCy5YtSU9PB9itPC/Ali1bWLZsGSeddBIjRozgjTfeAGDVqlUsW7as2CQ/b948TjrppF3lgEtSPrhwHE2aNOHggw9m3rx5tG7dmqVLl3LCCSfw7LPPMn/+fDp16gRAdnY2TZs2jewfQETK7PXXYeVKn+iDEK0zXqsCNcwsB6gJrAG6A+EC6eOAe4kwyQdVabhw4a/iyvPOmjWLmTNnMnfuXGrWrEm3bt3Yvn17qfdVuHwwsNs29ixANnDgQCZNmsQRRxxBv379MDOccwwaNIiHHnqo1PsWkehyDh5/3J/8FFpOosJFPCbvnPsFeAz4GZ/cN+GHZzY658JF0zOBA4p6vZldZWYZZpaRlZUVaTjlrrjyvJs2baJBgwbUrFmTpUuXMm/evL1uJz09ndmzZ/Pjjz8Cu5cPXrBgAeBXdwo/XpR+/frx1ltvMWHCBAaG6pX26NGDKVOm7CoZvGHDBlauXBnZmxaRMpkzBz7/HG66CZKSgokhGsM1DYC+QCtgIzAZ6FXS1zvnRgGjwFehjDSe8tazZ0+WLFlCl9BSLrVr1+bVV1+lV69ejBw5kjZt2nD44YfvGlYpTpMmTRg1ahRnn302+fn5NG3alPfff59zzjmHV155hXbt2nHcccdx2GGHFbuNBg0a0KZNGxYvXkznzp0BaNu2LQ888AA9e/YkPz+f5ORknn32WVq2bBm9fwQRKZEnnvA14wcNCi6GiEsNm9kAoJdz7vLQ7YuBLsAAYD/nXK6ZdQHudc6dtrdtVcZSw/FAn4FI6S1fDocdBnfcAf/3f+W7r72VGo7GFMqfgXQzq2l+fbsewGLgI6B/6DmDgNgqki4iUo4efRSqV4cbbgg2jmiMyX8GTAEW4KdPVsEPv9wO3Gxmy/HTKEdHui8RkXjwyy8wdqyvUdOsWbCxRGV2jXNuKDB0j7tXAJ2jtP1di2BLxYqllcNE4sXjj0N+PhQ6XzEwMX/Ga0pKCuvXr1eyCYBzjvXr15OSkhJ0KCJx47ff/MpPF14IqalBRxMHK0O1aNGCzMxM4mF6ZSJKSUmhRYsWQYchEjdGjIDsbH/ANRbEfJJPTk7edVaoiEgs27wZnn4a+vWDWJmQFvPDNSIi8WLkSNi4Ee68M+hICijJi4hEQXa2P/mpZ89gCpEVR0leRCQKXn4Z1q6NrV48KMmLiEQsJ8ef/NSlC3TtGnQ0u4v5A68iIrHulVd8OeFnnoFYO6VHPXkRkQjs3An33+/H4c88M+ho/kw9eRGRCLz8su/FP/987PXiQT15EZEy27EDHngA0tOhV4kLrFcs9eRFRMropZcgM9P35mOxFw/qyYuIlEl2tq8Tf+KJ0KNH0NEUTz15EZEyGDUKVq+G8eNjtxcP6smLiJTatm3w0ENw8snQrVvQ0eydevIiIqX0/PP+7NbJk4OOZN/UkxcRKYUtW+CRR+DUU/14fKxTkhcRKYXhwyErC+67L+hISkZJXkSkhNat8zVq+vXzc+PjgZK8iEgJ3X+/nzr50ENBR1JySvIiIiWwfLlfFOSKK+Dww4OOpuSikuTNrL6ZTTGzpWa2xMy6mFlDM3vfzJaFLhtEY18iIkEYMgSqVYOhQ4OOpHSi1ZN/CpjunDsC6AAsAe4APnDOtQY+CN0WEYk7X3wBkybBP/4BzZsHHU3pRJzkzawecBIwGsA5t9M5txHoC4wLPW0ccFak+xIRqWjOwW23QdOmPsnHm2j05FsBWcDLZvalmb1kZrWAZs65NaHn/Ao0K+rFZnaVmWWYWUZWVlYUwhERiZ5334VZs+Cf/4Q6dYKOpvSikeSrAscAzzvnOgJb2WNoxjnnAFfUi51zo5xzac65tCZNmkQhHBGR6MjLg9tvh0MPhauuCjqasolGks8EMp1zn4VuT8En/bVm1hwgdLkuCvsSEakwY8fCokW+2mRyctDRlE3ESd459yuwyszCk4p6AIuBacCg0H2DgLci3ZeISEXZtAnuugtOOAH69w86mrKLVoGyvwPjzawasAK4FP8FMsnMLgdWAudGaV8iIuXuvvt8+YJ3343tUsL7EpUk75z7Ckgr4qEYLqUvIlK0pUthxAi4/HI45pigo4mMzngVESnEObjxRqhVCx58MOhoIqd68iIihbz9Nrz3nq822bRp0NFETj15EZGQHTvgppugTRu49tqgo4kO9eRFREKefBJ++MH35ON1yuSe1JMXEQHWrIEHHoA+faBnz6CjiR4leRER4OabYedOePzxoCOJLiV5Ean0pk+HiRN9OeFDDw06muhSkheRSm3bNrjmGr8QyO23Bx1N9OnAq4hUavfdBz/+6CtNVq8edDTRp568iFRaCxf6MfjLLoOuXYOOpnwoyYtIpZSf78sH168Pjz4adDTlR8M1IlIpjRoF8+bBK69Ao0ZBR1N+1JMXkUpnzRq44w7o0QP++tegoylfSvIiUqk452fTbN8Ozz8f32WES0LDNSJSqfzrX/DmmzBsGLRuHXQ05U89eRGpNFavhuuug+OP94XIKgMleRGpFJyDK6/0lSbHjoWkpKAjqhgarhGRSmHsWHjnHXjqqcoxTBOmnryIJLxVq/xqT926+eGaykRJXkQSmnN+rda8PBgzBqpUsqyn4RoRSWgjR8L77/vpkq1aBR1NxYvad5qZJZnZl2b2duh2KzP7zMyWm9lrZlYtWvsSESmJRYt8nfjTToOrrw46mmBE84fLDcCSQrcfAYY75w4Ffgcuj+K+RET2KjsbBg6EevVg3LjEP+mpOFFJ8mbWAjgTeCl024DuwJTQU8YBZ0VjXyIiJXHzzfDtt742TbNmQUcTnGj15J8EbgPyQ7cbARudc7mh25nAAUW90MyuMrMMM8vIysqKUjgiUplNnerH4m+9NbHWay2LiJO8mfUG1jnn5pfl9c65Uc65NOdcWpMmTSINR0QquZ9/hiuugE6d/MLclV00ZtecAPQxszOAFKAu8BRQ38yqhnrzLYBforAvEZFi5ebCBRf46ZITJkA1TfeIvCfvnLvTOdfCOZcKDAQ+dM5dCHwE9A89bRDwVqT7EhHZm3vvhU8/9dMlDzkk6GhiQ3meFnA7cLOZLceP0Y8ux32JSCU3bRo8+CBceilceGHQ0cSOqJ4M5ZybBcwKXV8BdI7m9kVEivLdd3DRRXDssfDcc0FHE1sq2Qm+IpJo/vgDzj7bj7+//jqkpAQdUWxRWQMRiVvOwWWXwdKlvnTBQQcFHVHsUZIXkbj12GMwZYpf5al796CjiU0arhGRuPTBB34x7nPPhVtuCTqa2KUkLyJxZ+lS6N8f2rSB0aMrb12aklCSF5G4kpUFZ57pD7S+/TbUrh10RLFNY/IiEjeys6FvX78g98cfQ2pq0BHFPiV5EYkL+flwySUwbx5MngyddRZOiSjJi0hcuPtumDTJz6Q555ygo4kfGpMXkZj30kvw0EN+dSfNpCkdJXkRiWmvv+6T+2mnwdNPayZNaSnJi0jMmjHDL+F33HF+IZDk5KAjij9K8iISkz79FM46C9q2hf/8B2rVCjqi+KQkLyIx56uv/Fz4Fi3gvfegQYOgI4pfSvIiElO+/96vy1q3LsycWbkX4Y4GJXkRiRnff19QaExVJaNDSV5EYsLixdC1K+zc6Xvwhx8edESJQUleRAL3zTfQrZu/PmsWHHVUkNEkFiV5EQnUggVw8sm+4NjHH/vZNBI9SvIiEpjPPvNj8HXqwOzZcNhhQUeUeCJO8mZ2oJl9ZGaLzexbM7shdH9DM3vfzJaFLjUJSkR2ee89OOUUaNzYJ/iDDw46osQUjZ58LnCLc64tkA5ca2ZtgTuAD5xzrYEPQrdFRBgzxs+DP/RQ+O9/NYumPEWc5J1za5xzC0LX/wCWAAcAfYFxoaeNA86KdF8iEt+cg3vvhcsv97342bOhefOgo0psUS01bGapQEfgM6CZc25N6KFfAZ3SIFKJ5eT4QmMvvwyXXgovvKBaNBUhagdezaw2MBW40Tm3ufBjzjkHuGJed5WZZZhZRlZWVrTCEZEY8vvv0Lu3T/BDh/p1WZXgK0ZUkryZJeMT/Hjn3Ouhu9eaWfPQ482BdUW91jk3yjmX5pxLa9KkSTTCEZEYsmgRdOoEH33k68Lfe6/KBVekaMyuMWA0sMQ590Shh6YBg0LXBwFvRbovEYkvkydDejps3epPcrr88qAjqnyi0ZM/AbgI6G5mX4XaGcDDwKlmtgw4JXRbRCqBvDy44w4491x/9ur8+XD88UFHVTlFfODVOfcJUNyPrx6Rbl9E4svatXDRRb7A2ODB8NRT/mxWCYYW8haRqHnnHT9zZvNmP/6u4ZngqayBiERs+3a4/np/glOzZpCRoQQfK5TkRSQiixZB585+ke0bboDPP4d27YKOSsKU5EWkTHJy4KGHIC3Nj8O/+y48+SSkpAQdmRSmMXkRKbXPP4crr/R14Pv3h2ee0TJ9sUo9eREpsS1b4KaboEsXWL8e3nzTz4VXgo9dSvIisk/OwZQp0L69H5IZPNgv19e3b9CRyb5ouEZE9iojw/feP/kEjjzSX55wQtBRSUmpJy8iRcrMhIsv9nVnvv8eRo2CL79Ugo836smLyG7WrYNhw+DZZyE/35cnuPNOqFs36MikLJTkRQSAX3/1yf3552HHDrjgArj/fkhNDToyiYSSvEglt2oVDB8OI0f65P7Xv8KQIVpUO1EoyYtUQs7BnDkwYgRMnervCyf31q2DjU2iS0lepBLZvt1PhXzqKT9rpn59P3Pm2ms1LJOolORFEpxz/gzVceNgwgTYuBGOOAKee87PnqlVK+gIpTwpyYskqJ9+gokTfXJfuhRq1IB+/eCSS6BHD6iiCdSVgpK8SIJwzp+F+vrr8MYbfk47wF/+4mu7DxigaZCVkZK8SBzbsgVmz4aZM+Htt2HZMn9/ly5+OuTZZ8PBBwcbowRLSV4kjmzd6g+YfvyxT+xz50JuLlSvDl27ws03+3oyzZsHHanECiV5kRiVk+PLCWRkwLx5vi1c6BfJNoNjj4V//ANOOcUvkl2jRtARS4k559dIXL0a1qzxl4cf7mtIRJmSvEjAsrNhxQr44QdYssQn8oUL/fWcHP+cevXguOP8PPb0dH+9YcNg45Y95OX5+stZWbu3dev86cRr1xa0NWv8B1/YLbfEZ5I3s17AU0AS8JJz7uHy3qdILHDOj5mvX+//xlev9u2XX/zlTz/B8uX+emEHHuirPZ5+ur/s2NFPedRsmAqQkwObNvle9qZNBe33333buLHg+vr1vm3Y4C83bvQfelEaN/ZF95s189/QzZvD/vvvfnnAAeXylso1yZtZEvAscCqQCXxhZtOcc4vLc79ScbZuLeikbNjg/58Xblu2wLZtBS0725+Qk5vr/55ycvz1vDy/vT3/RpKSClqVKlC1qm/JyQWX4Vat2u6XRV0Pv75wq1Jl92bm48jPL7jMz/ex7tgBO3f6tmOHf/9bt/r3GW6//+7/LTZsKOiJF1a1qv+bbtkSTj0VDj0UDjnEt8MO8ycoSSHO+f8gO3b8uW3fvnvLzi5ohf/jhT+owh/Y1q3wxx8FbfNmv829MfM/q+rXh0aNfDv44ILrjRtDkya7t8aN/X/AgJR3T74zsNw5twLAzCYCfQEl+TiQlwcrV8KPP/peZ+G2Zo1P7Fu2FP/6mjWhTh1/WaNGwWWdOn9O0OHkCgWX4b/tvDyfZPPy/BdCuOXk+L/rzZsLvjB27iz6MtyiJSnJUa1qPrVq5FM7JZdaKXn+sloubevvpFGLHTSsuZ1GtbbTsOZ29quzlf3rbWX/On/QuOY2qri83b9F1ubDmnyYHbqv8LdMSVr4H6zw9X3dV9TtvbX8PWILt/AHFf6QwveHP7DwhxhuhT/EvLzdP6DCLfxtunNn8T3kkkpJgdq1/Zlf4ctatfw3bZ06Ba1uXZ/Ew5fhVr8+NGjg709KiiyWClbeSf4AYFWh25nAceW8Tykl5+Dnn2H+fPj2Wz8WvHgxfPedT6JhSUl+KKFlS+jcGfbbr+AXaLNmvsNSv75v9er5nnPU5eX5b5bNm3fvhYV7ZuHu9Natu/fktm3Dbd1GfvYOcrNzyNmWQ252Drnbc3Hbd5C/M5f8HTm+5eRShfxdzXBUIZ9q7KQaO0kmh6S8fMgD9tHxqzCFvyH3/Lbc231F3S6qhb+Fw9cL/+wx2/3nVrhVrbr7T7GkJH9ftWr+Gz98X+Fv+8I/uwq35GQ/haioVqOGT+Lhy5QUn8DDvYoaNSr1WFfgB17N7CrgKoCDDjoo4Ggqhw0b4NNP4YsvfMvIgN9+K3g8NRXatPGzNtq08cMIqanQooX/G41YTk7BOGbhcc3wWGd4vGPjxt3HRTdu9Mm7pMLJJNSsRg2SatYkKSWF6k1r7p4UwgkjJcW/rnr1ggRTvXrRiSjcCo/9hBNZUQkunAT3TIaFE+aelyVpIntR3kn+F+DAQrdbhO7bxTk3ChgFkJaWFuFvMinKhg3+hJlZs3z75hvfe09KgnbtoE8fSEvzU/LatStDLZPwdLDw4Hzhy3Xr/jzTYNOm4rdVpYr/WdygQcHPgv333/2nc+Gf1nXq+J/fdeoU/BSvXdsn9qh8I4nEt/L+K/gCaG1mrfDJfSBwQTnvs9Jzzk/B+/e//VmQn33m70tJ8Uu33XefP3Hm2GN9LtyrvDyfsFet8i0z07fwFJFw27btz69NSio4+NS0qd9h+EBU+EBVuDVs6FudOuqdikRRuSZ551yumV0HvIefQjnGOfdtee6zssrPh//+FyZP9sn955/9/WlpMHSoL0jVqZMfedhNXp5P2CtW+LZypT+yunKlb5mZ/gBZYSkpfuxm//39DsJTwJo3Lxio328/n7Qr8VioSCwo99+zzrl3gHfKez+V1TffwPjxvoTsqlV+mPmUU+Duu+HMM33+JSfHT5GZucwXN1m2zE/QDif1wtNOqlTxL2rZ0p9G2bIlHHSQT+oHHugvGzZUb1skTmjQMg79/rsvHzt6NCxa5IeeTzsNHr5nC31bLaTWz0t8bdnBS/3lihUFE9HBj2sfeigccwz07+/n+R58MLRq5RN5gHN6RSS6lOTjyPz5fqGHCRMc2dlGequ1PNttNgPyJtJkwVz4z5qCJ1ev7s+sOfpoOO88v6ZbuDVqpJ64SCWhJB/LnCN3xc9Mfm4dwyfuzxerD6CWbeVi9//4G8/T4cdvYG1NaNsWevb0U2PatvXzHlu2jLuTNkQk+pTkY4VzflglIwPmz2dHxkLGfXYEj2y7jhV04giW8HTTMVx0/HLqHdsajrrfFzZp2VIHN0WkWEryQcnK8rVjP/us4IykDRvYQi1GJV3D41XGsTqnKZ1aruXxa76nzzUtqVL7nqCjFpE4oyRfEfLy/DSYOXN8Yp8719eVhV1nJOX07c9L2//K/85IZ+36ZLp3hVfugu7dm2HWLNj4RSRuKcmXh23b4PPP4ZNPfJszx9dXAT9/vEsXuPpqSE/HHXMsr0+vyV13+QUiTjwRXn/Yz14UEYmUknw0ZGf73nm4bsBnn/nKeWbQvj389a9+NeUTTvBzzkMzW+bMgVtO8Z37tm1h2jTo3VsTX0QkepTkyyI314+jz5zp27x5PqlXqeLPAL3xRt8lP+EEX4NlD7/9BrffDmPG+POOXnoJBg1SqRURiT6llZJatgymT/dJfdYsX5DLzJ9QdP31cPLJvrdet26xm8jPh5df9gl+0ya49Vb45z99PS0RkfKgJF+crVvho498Yp8+veBA6SGHwPnn+9oBJ5/sTywqgUWLYPBgX+L3L3+B55/3IzkiIuVJSb6wlSt92cZ//9sn+J07fZnGHj3g5pt97YBDDinVJvPy4PHH4Z57fIHF0aPhkks0tV1EKkblTvL5+X5++ptv+uS+cKG//7DD4Lrr4IwzfLf7T6UbS2bFCj/W/skn0K8fvPCCr7QrIlJRKl+Sz8mBjz+GN96At97yZXaTkuCkk3yXu3dvn+Qj4Jzvsd90k++xv/KKn2CjWTMiUtEqR5LfudMfMJ082Sf233/3NXl79fJd7DPP9OVzo+D33/1wzLRp0L27P9CqVQ1FJCiJm+R37oT33y9I7Bs3+hK7ffvC2WfDqaeWYFmk0lmwwFfuzcyE4cP9pBuNvYtIkBIryYeXR/rXv2DKFL+4ab16cNZZMGCAnxFTxvH1vXEOXnzRJ/WmTX0Ixx0X9d2IiJRaYiT5Zcv8Uc2JE/0Ye61avsd+/vm+x14OiT1s2zb429/8uHvPnn6VpsaNy213IiKlkhhJfskSGDECTj8dHnsM/ud/fKIvZz//7He1cKFfR/Wee1TCXURiS2Ik+V694Ndfo3bwtCQyMnyC37YN/vMf//0iIhJrEuOwYLVqFZrgX3/dz7isXt0XGVOCF5FYFVGSN7NhZrbUzL4xszfMrH6hx+40s+Vm9p2ZnRZxpDHAOXj0UTjnHOjQwRebbNcu6KhERIoXaU/+faC9c+4o4HvgTgAzawsMBNoBvYDnzCyuR6tzc30J+Ntvh3PPhQ8/hGZay0NEYlxESd45N8M5lxu6OQ9oEbreF5jonNvhnPsRWA50jmRfQdqxAwYO9NMk77oLJkzw51KJiMS6aI7JXwa8G7p+ALCq0GOZofv+xMyuMrMMM8vIysqKYjjRsW2bn405dao/wenBB3WCk4jEj33OrjGzmcB+RTw0xDn3Vug5Q4BcYHxpA3DOjQJGAaSlpbnSvr48bdrkS9nMmeNr0Vx2WdARiYiUzj6TvHPulL09bmaXAL2BHs65cJL+BTiw0NNahO6LG1lZfmbmwoX+HKsBA4KOSESk9CKdXdMLuA3o45zbVuihacBAM6tuZq2A1sDnkeyrIq1dC127wuLFvuyNEryIxKtIT4Z6BqgOvG++ju4859xg59y3ZjYJWIwfxrnWOZcX4b4qxG+/+RI3K1f6BaG6dg06IhGRsosoyTvnDt3LYw8CD0ay/Yq2caOvP7N8uT+LVQleROJdYpQ1iILNm/0Y/Lff+iGa7t2DjkhEJHJK8vg1u888E+bP9xWKe/UKOiIRkeio9El++3bo08dPk5w40c+JFxFJFJU6yefn+4W2P/zQ14PXLBoRSTSV+tzNW2+FSZNg2DC46KKgoxERib5Km+SffBKeeAL+/ne45ZagoxERKR+VMslPngw33+zX8x4+HPwUfxGRxFPpkvx//+uHZo4/Hl59Vcv1iUhiq1RJ/vvv/Uya1FQ/F17lgkUk0VWaJL9pk58eWbUqvPsuNGoUdEQiIuWvUkyhzMuDCy/05QpmzoRWrYKOSESkYlSKJH/PPb4WzXPPqR6NiFQuCT9c89pr8NBDcNVVMHhw0NGIiFSshE7yX34Jl14KJ5wATz+tqZIiUvkkbJLPyoKzzvIHWKdOhWrVgo5IRKTiJeSYfH6+nwu/di188gk0axZ0RCIiwUjIJP/oo/Dee/5Aa1pa0NGIiAQn4YZrPvkE7r4bzj1XB1pFRBIqyf/2Gwwc6M9offFFHWgVEUmY4Zr8fLj4Yn/Add48qFs36IhERIKXMEl+2DBfruDZZ6Fjx6CjERGJDVEZrjGzW8zMmVnj0G0zsxFmttzMvjGzY6Kxn+J8+ikMGeJXdvrb38pzTyIi8SXiJG9mBwI9gZ8L3X060DrUrgKej3Q/e1OzJpxyisbhRUT2FI2e/HDgNsAVuq8v8Irz5gH1zax5FPZVpI4dYfp0qFevvPYgIhKfIkryZtYX+MU59/UeDx0ArCp0OzN0X1HbuMrMMswsIysrK5JwRERkD/s88GpmM4H9inhoCHAXfqimzJxzo4BRAGlpaW4fTxcRkVLYZ5J3zp1S1P1mdiTQCvja/EB4C2CBmXUGfgEOLPT0FqH7RESkApV5uMY5t9A519Q5l+qcS8UPyRzjnPsVmAZcHJplkw5scs6tiU7IIiJSUuU1T/4d4AxgObANuLSc9iMiInsRtSQf6s2Hrzvg2mhtW0REyiahateIiMjulORFRBKY+ZGV2GBmWcDKMr68MfBbFMMJkt5LbEqU95Io7wP0XsJaOueaFPVATCX5SJhZhnMuIZYI0XuJTYnyXhLlfYDeS0louEZEJIEpyYuIJLBESvKjgg4givReYlOivJdEeR+g97JPCTMmLyIif5ZIPXkREdlDQiV5M7s/tBLVV2Y2w8z2DzqmsjKzYWa2NPR+3jCz+kHHVFZmNsDMvjWzfDOLu5kQZtbLzL4LrXR2R9DxlJWZjTGzdWa2KOhYImVmB5rZR2a2OPR/64agYyoLM0sxs8/N7OvQ+/jfqO8jkYZrzKyuc25z6Pr1QFvn3OCAwyoTM+sJfOicyzWzRwCcc7cHHFaZmFkbIB94AfiHcy4j4JBKzMySgO+BU/FF+L4AznfOLQ40sDIws5OALfgFfdoHHU8kQosQNXfOLTCzOsB84Kx4+1zMl/Ct5ZzbYmbJwCfADaHFlqIioXry4QQfUovdV6uKK865Gc653NDNefhyzXHJObfEOfdd0HGUUWdguXNuhXNuJzARv/JZ3HHOzQY2BB1HNDjn1jjnFoSu/wEsoZiFiWJZaPW8LaGbyaEW1byVUEkewMweNLNVwIXAP4OOJ0ouA94NOohKqsSrnEkwzCwV6Ah8FnAoZWJmSWb2FbAOeN85F9X3EXdJ3sxmmtmiIlpfAOfcEOfcgcB44Lpgo927fb2X0HOGALn49xOzSvJeRKLNzGoDU4Eb9/glHzecc3nOuaPxv9Y7m1lUh9LKq558uSlupaoijMfXtR9ajuFEZF/vxcwuAXoDPVyMHzwpxecSb7TKWYwKjWFPBcY7514POp5IOec2mtlHQC8gagfH464nvzdm1rrQzb7A0qBiiZSZ9QJuA/o457YFHU8l9gXQ2sxamVk1YCB+5TMJUOiA5WhgiXPuiaDjKSszaxKeOWdmNfAH+KOatxJtds1U4HD8TI6VwGDnXFz2usxsOVAdWB+6a14czxTqBzwNNAE2Al85504LNKhSMLMzgCeBJGCMc+7BYCMqGzObAHTDVztcCwx1zo0ONKgyMrO/AP8FFuL/3gHucs69E1xUpWdmRwHj8P+3qgCTnHP3RXUfiZTkRURkdwk1XCMiIrtTkhcRSWBK8iIiCUxJXkQkgSnJi4gkMCV5EZEEpiQvIpLAlORFRBLY/wdVsP4jv7Ev2wAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# 画出更新一次之后的模型\n", - "y_pred = multi_linear(x_train)\n", - "\n", - "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", - "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "因为只更新了一次,所以两条曲线之间的差异仍然存在,我们进行 100 次迭代" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 20, Loss: 65.56586\n", - "epoch 40, Loss: 15.41177\n", - "epoch 60, Loss: 3.70702\n", - "epoch 80, Loss: 0.97122\n", - "epoch 100, Loss: 0.32874\n" - ] - } - ], - "source": [ - "# 进行 100 次参数更新\n", - "for e in range(100):\n", - " y_pred = multi_linear(x_train)\n", - " loss = get_loss(y_pred, y_train)\n", - " \n", - " w.grad.data.zero_()\n", - " b.grad.data.zero_()\n", - " loss.backward()\n", - " \n", - " # 更新参数\n", - " w.data = w.data - 0.001 * w.grad.data\n", - " b.data = b.data - 0.001 * b.grad.data\n", - " if (e + 1) % 20 == 0:\n", - " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可以看到更新完成之后 loss 已经非常小了,我们画出更新之后的曲线对比" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqW0lEQVR4nO3de5yMdf/H8ddnT5Zdp3UKK7tKzhYtrYqUcogciogcuivc3UK5cxfuXyUqN6XIoRVFlFMlFTkfUlQoOW05RBZrt3Vclt3Z+f7+mKGtFrs7M3vNzH6ej8f1mJlrrrm+n2vx9t3vXNf3EmMMSiml/FOA1QUopZTyHA15pZTyYxrySinlxzTklVLKj2nIK6WUH9OQV0opP5brkBeRmSKSLCI7s62LEJGVIrLX+VjauV5EZKKI7BORn0SkkSeKV0opdXV56cm/B7T5y7pngdXGmOrAaudrgLZAdefSD5jqWplKKaXyQ/JyMZSIRAGfG2PqOl//DLQwxhwTkYrAOmNMDRF52/n8w79ud7X9ly1b1kRFReXvSJRSqpDaunXr78aYcjm9F+TivitkC+4koILzeWXgcLbtEp3rrhryUVFRbNmyxcWSlFKqcBGRQ1d6z21fvBrHrwR5niNBRPqJyBYR2ZKSkuKucpRSSuF6yB93DtPgfEx2rj8CVMm2XaRz3d8YY+KNMbHGmNhy5XL8bUMppVQ+uRryS4A+zud9gE+zre/tPMsmDjh9rfF4pZRS7pfrMXkR+RBoAZQVkUTgeeBVYIGIPAocAh50br4UuBfYB5wHHslvgZmZmSQmJnLhwoX87kK5IDQ0lMjISIKDg60uRSmVD7kOeWPMQ1d4q2UO2xrgX/ktKrvExESKFy9OVFQUIuKOXapcMsaQmppKYmIi0dHRVpejlMoHr7/i9cKFC5QpU0YD3gIiQpkyZfS3KKV8mNeHPKABbyH92Svl23wi5JVSyp+NGgUbNnhm3xryuTBx4kRq1apFz549WbJkCa+++ioAixcvZvfu3Ze3e++99zh69Ojl14899tif3ldKqb/65Rd4/nlYv94z+3f1itdCYcqUKaxatYrIyEgAOnToADhCvn379tSuXRtwhHzdunWpVKkSAO+88441BWdjs9kICtI/ZqW81ZQpEBxg4/E63wG3un3/2pO/hgEDBnDgwAHatm3LhAkTeO+99xg4cCDffPMNS5Ys4ZlnnqFBgwaMHTuWLVu20LNnTxo0aEB6ejotWrS4PE1DeHg4I0aMICYmhri4OI4fPw7A/v37iYuLo169eowcOZLw8PAc65g9ezb169cnJiaGXr16AdC3b18WLVp0eZtLn123bh3NmjWjQ4cO1K5dm2effZbJkydf3u6FF15g/PjxAIwbN47GjRtTv359nn/+eff/AJVSV5SWBu/OyKKrfT7X7V7jkTZ8q4s3ZAj8+KN799mgAbzxxhXfnjZtGl9++SVr166lbNmyvPfeewDceuutdOjQgfbt29OlSxcAli1bxvjx44mNjf3bfs6dO0dcXBxjxoxh2LBhTJ8+nZEjRzJ48GAGDx7MQw89xLRp03KsYdeuXYwePZpvvvmGsmXLcuLEiWse1rZt29i5cyfR0dH88MMPDBkyhH/9y3FW64IFC1i+fDkrVqxg7969fPfddxhj6NChAxs2bKB58+bX3L9SynVz5sCZtEAGBsdD/0XX/kA+aE++gISEhNC+fXsAbr75Zg4ePAjApk2b6Nq1KwA9evTI8bNr1qyha9eulC1bFoCIiIhrttekSZPL57Y3bNiQ5ORkjh49yvbt2yldujRVqlRhxYoVrFixgoYNG9KoUSMSEhLYu3evq4eqlMoFY+CtN7NoJD8Q16MaeGhaF9/qyV+lx+3tgoODL5+OGBgYiM1mc3mfQUFB2O12AOx2OxkZGZffCwsL+9O2Xbt2ZdGiRSQlJdGtWzfAcbHTc889R//+/V2uRSmVN+vXw66EQGYyERk8yGPtaE/eBcWLF+fs2bNXfJ0bcXFxfPTRRwDMmzcvx23uuusuFi5cSGpqKsDl4ZqoqCi2bt0KwJIlS8jMzLxiO926dWPevHksWrTo8m8OrVu3ZubMmaSlpQFw5MgRkpOTr7gPpZT7vDXJEBFwku63HoaGDT3Wjoa8C7p37864ceNo2LAh+/fvp2/fvgwYMODyF6+58cYbb/D6669Tv3599u3bR8mSJf+2TZ06dRgxYgR33HEHMTExPP300wA8/vjjrF+/npiYGDZt2vS33vtf93H27FkqV65MxYoVAWjVqhU9evSgadOm1KtXjy5duuT5PymlVN4dPgyLFxses8dT9Ol/erStPN0ZytNiY2PNX28asmfPHmrVqmVRRZ53/vx5ihYtiogwb948PvzwQz799NNrf7AA+fufgVIFbeRIeHmMnQOVmhF1aD24eJqziGw1xvz9jA98bUzeD23dupWBAwdijKFUqVLMnDnT6pKUUh508SLET7FxH18QNbijywF/LRryFmvWrBnbt2+3ugylVAFZuBBSTgYxMGQ6PDbb4+1pyCulVAF6641MasgBWvaJhFycDu0q/eJVKaUKyObN8O3WYJ4wkwkYNLBA2tSevFJKFZDXx9spKWd5pPmvULdugbSpPXmllCoAv/4KH30M/c00iv+74C5A1JAvAFFRUfz+++9Wl6GUstCbbxgCTBZP3vgl3HtvgbWrIZ8HxpjL0whoHUqp3Dp1CmZMz6I784gc3hsCCi56NeSv4eDBg9SoUYPevXtTt25dDh8+fMXpeTt16sTNN99MnTp1iI+Pv+a+v/zySxo1akRMTAwtWzruh559GmCAunXrcvDgwb/V8dJLL/HMM89c3u7SFMgAc+bMoUmTJjRo0ID+/fuTlZXlrh+HUiofpk+HtPQgni77PlxhIkJPccsXryLyFPAYYIAdwCNARWAeUAbYCvQyxmRccSe5YMFMwwDs3buXWbNmERcXd9XpeWfOnElERATp6ek0btyYBx54gDJlyuS4z5SUFB5//HE2bNhAdHR0rqYPzl5HSkoKTZs2Zdy4cQDMnz+fESNGsGfPHubPn8/XX39NcHAwTzzxBHPnzqV37955/MkopdwhMxMmvpbBnWyk4TN3Q5EiBdq+yyEvIpWBQUBtY0y6iCwAugP3AhOMMfNEZBrwKDDV1fasULVqVeLi4gD+ND0vQFpaGnv37qV58+ZMnDiRTz75BIDDhw+zd+/eK4b85s2bad68+eXpgHMzfXD2OsqVK0e1atXYvHkz1atXJyEhgdtuu43JkyezdetWGjduDEB6ejrly5d37QeglMq3hQsh8XgI04pOhf4Ff7c4d51CGQQUFZFMoBhwDLgLuPR7ySzgBVwMeatmGs4+8deVpuddt24dq1atYtOmTRQrVowWLVpw4cKFPLeVffpg4E/7+OsEZN27d2fBggXUrFmTzp07IyIYY+jTpw+vvPJKnttWSrmXMfDaKxepyQHa/jMKcpiA0NNcHpM3xhwBxgO/4Qj30ziGZ04ZYy5Nmp4IVM7p8yLST0S2iMiWlJQUV8vxuCtNz3v69GlKly5NsWLFSEhIYPPmzVfdT1xcHBs2bODXX38F/jx98LZt2wDH3Z0uvZ+Tzp078+mnn/Lhhx/SvXt3AFq2bMmiRYsuTxl84sQJDh065NpBK6XyZf162LazCE8HvEnAU4MtqcEdwzWlgY5ANHAKWAi0ye3njTHxQDw4ZqF0tR5Pa9WqFXv27KFp06aA476qc+bMoU2bNkybNo1atWpRo0aNy8MqV1KuXDni4+O5//77sdvtlC9fnpUrV/LAAw8we/Zs6tSpwy233MJNN910xX2ULl2aWrVqsXv3bpo0aQJA7dq1GT16NK1atcJutxMcHMzkyZOpWrWq+34ISqlceX1sBuU4zcPdbRAZaUkNLk81LCJdgTbGmEedr3sDTYGuwHXGGJuINAVeMMa0vtq+CuNUw75A/wyUyrs9e6B2bXieF3jhpwegXj2PtXW1qYbdcQrlb0CciBQTx/3tWgK7gbVAF+c2fQDvmiRdKaU86NUxNopxnoEtEzwa8NfijjH5b4FFwDYcp08G4Bh++Q/wtIjsw3Ea5QxX21JKKV9w8CDM/TCA/kyj7MgBltbilrNrjDHPA8//ZfUBoImb9n/5JtiqYHnTncOU8hXjXs0iwJ7F0MZfwR1PWVqL11/xGhoaSmpqqoaNBYwxpKamEhoaanUpSvmMpCSYMcPQl/eoPPqfYHEH1eunGo6MjCQxMRFfOL3SH4WGhhJp0VkBSvmiCeOzyLQJw+ovh3sWWV2O94d8cHDw5atClVLKm504AVMm2+nGQm4c84jlvXjwgeEapZTyFW9NtJN2IZhnayyGdu2sLgfwgZ68Ukr5grQ0ePO1TO5jOfXHdPOKXjxoT14ppdwi/m07J9KKMDx6HnTubHU5l2lPXimlXHTxIowfc5E72UTcmPsK9KYg16Ihr5RSLpoebzh2siizK8+GB73ruk8NeaWUckF6Orz8/AWa8T0tX2oBgYFWl/QnGvJKKeWCt6faOXayKB9ETkd6vWt1OX+jIa+UUvl07hy88uJF7uIbWoxvD0HeF6ne8+2AUkr5mCmTskg+U5QXb3gfuna1upwced9/O0op5QPOnoWxozNpxSpun/CAV51Rk513VqWUUl7urTcyST0Xyqg6C6B9e6vLuSLtySulVB6dPu2YTrgdy7nlzR5ec3VrTrQnr5RSefTmuAxOng9lVOPPoWVLq8u5Kg15pZTKg5Mn4fXX7HTiExpN7Gt1OdekIa+UUnnw6gsXOHMhhBear4G4OKvLuSYNeaWUyqXffoM3JwfyMHOImfiY1eXkioa8Ukrl0v89nQZZWbz0wHaIibG6nFzRkFdKqVz46SeY/VExngyaStU3rL05d164JeRFpJSILBKRBBHZIyJNRSRCRFaKyF7nY2l3tKWUUlZ4dsApSnGK4YPPgw/d99hdPfk3gS+NMTWBGGAP8Cyw2hhTHVjtfK2UUj5nzWrDsk2lGB4+idLPD7K6nDxxOeRFpCTQHJgBYIzJMMacAjoCs5ybzQI6udqWUkoVNLsdhvU/xfUcYuCrkVC8uNUl5Yk7evLRQArwroj8ICLviEgYUMEYc8y5TRJQIacPi0g/EdkiIltSUlLcUI5SSrnPgg9sbN1fmtGVpxHav4/V5eSZO0I+CGgETDXGNATO8ZehGWOMAUxOHzbGxBtjYo0xseXKlXNDOUop5R4XL8LwIeeI4Ud6vt3cK6cSvhZ3hHwikGiM+db5ehGO0D8uIhUBnI/JbmhLKaUKzFv/O8+vqSUZ23A+Afe2sbqcfHE55I0xScBhEanhXNUS2A0sAS79btMH+NTVtpRSqqAcPw6jRgv3spTW73b36knIrsZdv3s8CcwVkRDgAPAIjv9AFojIo8Ah4EE3taWUUh43fMAJ0jPCmdDje4i51+py8s0tIW+M+RGIzeEt756eTSmlcvD9d4aZiyN4JnQSN0160upyXOJ73yIopZQH2e0wqOfvVCCLkWOLQ0SE1SW5RKc1UEqpbOZMT2fzvnKMjX6bEgN7W12Oy7Qnr5RSTmfPwn+GZtKEn+j1QVuvvW9rXvj+ESillJuMHvI7SedKMLHDagLimlhdjltoyCulFLD3F8OEd0vSJ/gDbnnncavLcRsNeaVUoWcMPNHlOKEmnVdezAA/uvpeQ14pVejNfTuNVTuu45UqU6k4rJfV5biVhrxSqlBLTYWnnjLcwrcM+KQ1BAZaXZJbacgrpQq1YQ8f5eSFosT/YzOBNzewuhy305BXShVa61dcZOaXlRhaaib13+pndTkeoefJK6UKpYsXof9Dp4niHM9/UAOKFrW6JI/QnrxSqlAaO+QYP58oz9TWiynW9g6ry/EYDXmlVKHz8y4bY94uQ/cin9Dmw75Wl+NRGvJKqUIlKwv+0S6JYuYcEyYFQenSVpfkURrySqlCZcKwo3xzKJJJsbO57rH2VpfjcRrySqlCY/ePGYycUIZORZbS84sePnu3p7zQs2uUUoWCzQZ97z1OuCnKtPhApLz/TF1wNdqTV0oVCmOfOMT3x6ow9a6FVOjd2upyCoyGvFLK723fdJ4Xp1ekW7HP6Lr4YavLKVA6XKOU8msZGdDnvlQiCGbyvDJQvLjVJRUot/XkRSRQRH4Qkc+dr6NF5FsR2Sci80UkxF1tKaVUbr3Yax/bU6sQ32kZZe671epyCpw7h2sGA3uyvR4LTDDG3AicBB51Y1tKKXVNaxb8zisLqvGP0h/T4cOHrC7HEm4JeRGJBNoB7zhfC3AXsMi5ySygkzvaUkqp3EhJyuLhXlBDfmHiytoQGmp1SZZwV0/+DWAYYHe+LgOcMsbYnK8Tgcpuaksppa7KGHik+X5OZIQz78VfCLu5ptUlWcblkBeR9kCyMWZrPj/fT0S2iMiWlJQUV8tRSineHPgLX+y9ifFNFhLz3w5Wl2Mpd/TkbwM6iMhBYB6OYZo3gVIicunsnUjgSE4fNsbEG2NijTGx5fzovopKKWtsW5nKsClRdAxfxb9Wdba6HMu5HPLGmOeMMZHGmCigO7DGGNMTWAt0cW7WB/jU1baUUupq0s7Y6d4pnfKkMGNpJaR4uNUlWc6TF0P9B3haRPbhGKOf4cG2lFKFnDEwoPku9p+vyNyh2yjTrLbVJXkFt14MZYxZB6xzPj8ANHHn/pVS6komDdjF3O31eKn+Qu4Y1+XaHygkdFoDpZTP2zDnN56Or0HHkmsZ/nW7QjG7ZG5pyCulfFrinrN07VuMGwIPMuurGwgIL2Z1SV5FQ14p5bMuptvpcttRzmcVYfGME5Ssd73VJXkdDXmllM8adPs2vj1Zg1mPfkWtPvoVYE405JVSPmn6Ez8Qvy2WZ+t9wf3T21pdjtfSkFdK+ZzVkxN4YmpdWpfcxOhNLfWL1qvQkFdK+ZTdyw/zwJMVqRl8gPnf30BgWOGceCy3NOSVUj7j+J4T3Ns+gKKk88XyIEpWL291SV5PQ14p5RPOn7hAhyZJpNhK8Vl8EtffeYPVJfkEDXmllNez2+z0qr+d79Nq8sG/fyD2sQZWl+QzNOSVUl7N2A3/jvuKj4/cwusd1tFx3O1Wl+RTNOSVUl7t5VbrmLD1Dp6sv57Bn9xpdTk+R0NeKeW1JnVew8jVd9Kr2te8sbUZEqCnSuaVhrxSyivN6rOGQYvvolOl75i56xYCgjSu8kN/akopr/PxoHX8Y/Yd3F32B+YlNCAo1K2zohcqGvJKKa+y4r9f0X3SrdxSIoHFCbUoUjzE6pJ8moa8UsprLH9uHR1Hx1Kr2CG+2FmVsDJ6NaurNOSVUl7hs0Er6fBqU2qGJbJ6V0VKV9H7s7qDhrxSynIfPbaM+ye1IKbEQVYnVKZslAa8u2jIK6Us9cFDn9Ftxj00idjHyr1RRETqnZ3cSUNeKWUNY3i3wyc8PK8dzcr/zPJ9N1KyfBGrq/I7Loe8iFQRkbUisltEdonIYOf6CBFZKSJ7nY+lXS9XKeUPzMUMXon9iH981pl7Ku/mi301CS8dbHVZfskdPXkbMNQYUxuIA/4lIrWBZ4HVxpjqwGrna6VUIWdLPc0/b1jB8G1d6FnvJz7bX4dixQOtLstvuRzyxphjxphtzudngT1AZaAjMMu52Sygk6ttKaV827mfE+kc/QNvH2nPc+1+4v3t9QkpolMVeJJbx+RFJApoCHwLVDDGHHO+lQRUcGdbSinfcnzVDlrU+52lZ5sxddAeXv68vt61rwC4LeRFJBz4CBhijDmT/T1jjAHMFT7XT0S2iMiWlJQUd5WjlPIi2/+3nLhWxdltu4nFkxIZ8GYtq0sqNNwS8iISjCPg5xpjPnauPi4iFZ3vVwSSc/qsMSbeGBNrjIktV66cO8pRSnkLm40P2s2l6X+akRlcjHVfnOe+gVWtrqpQccfZNQLMAPYYY17P9tYSoI/zeR/gU1fbUkr5DtvRZIZGf0zPpT2JrXiUrftK0bhtWavLKnTcMbXbbUAvYIeI/OhcNxx4FVggIo8Ch4AH3dCWUsoHpHy5lW6dLrD24oM8eU8Cr31Rk2A9Q9ISLoe8MWYjcKWvT1q6un+llA+x2/lq4Hx6Tr2NFCnPrFGH6P3fmlZXVajpJM1KKbfI/DWRUXet5eWDPYgOS2bjkovcfJeOv1tNpzVQSrls/1vLaFY9idEHe9H7tv38cPQ6br6rpNVlKbQnr5RygTl5itkdFjFwYzeCAg3z3zjGg4NvsroslY2GvFIq74zh0LRl/HNIEZZlPEbz63/l/TWRXH9DCasrU3+hIa+UypOsg4eZ1H45I3d1h4AAJgw9zJNjownU6We8ko7JK6Vyx2Zj+7Mf0vSG4zy16zGa10xm1y8hDBlfRQPei2lPXil1TamL1vJi/6NMOdGNiJA0PvhfMt0HVdO5Z3yAhrxS6ooyftzN5Ic2MiqhK2dozuP3HGLMB9GUKavp7it0uEYp9TfmWBKL275NnYbBPJ3QjyY3nWb71iymraimAe9jNOSVUpeZo8f4omM8t1ROpPOX/QkpHc6yeadY/nMUdRuFWF2eygcNeaUU5shRPrsvniaRR2m/pB8pYVWZPiaZ7ckVadOtlNXlKRfomLxShVjmDzv5aOg3jFsXyzbTj2rFk5kxPJleQ8vrhGJ+QkNeqcLGbiflg5XEP3+EKQdac5R+VC95nHeHJ9PzKQ13f6Mhr1QhYY4e47uXVxE/O5S5Z+/jIqG0qn6A6aPO0ObBCgTo4K1f0pBXyp9dvEjie6uYMyGFWT/fQgK9KBaQziP3JDJofFVq1a9mdYXKwzTklfI3GRn8vngjSyYfZv43VVhpa4shgGZVfuXf/ZPo+uR1lChxo9VVqgKiIa+UPzh/nt8+2Mjit4/zyY9RbLDdgZ1AosKS+W+3A/T+v2huuCna6iqVBTTklfJFdjtpm3fy1Ts/s2o1rDpcg59MKwDqlEpkeKt9dB4SRcO48oiUt7hYZSUNeaV8QWYmv6/fxbcLDvHt15ms31uJTZmxZFKfInKR26scYmyrX+g0JJqb6kRaXa3yIhrySnkbm43kTfvZsSyRHd+eZ+vOImxOqcY+0wBoQCA2GkT8xtNxP3N3r4rc1rEsRYvqjTpUzjTklbLIhZPpHNiQyP7NKezfcZ79+2HPkRLsOFuVZGoANQCoGJxCXLWjPH7LduI6V+TmtuUJC9OzYlTueDzkRaQN8CYQCLxjjHnV020q75WZCRcvOh5ttj8es7Ic7xvz5+0DAyEg4I/HoCDHEhz8x+It090aA+fPw6mThtSDZzlx4BSph9I4cSSdpMMZHD0CR1OCOHI6nKPpERyzVwCqOxcoLmepEX6EdnUOUS/mMPVaRFCvbRUqRJYDyll5aMqHeTTkRSQQmAzcAyQC34vIEmPMbk+2qzwrPR2OHYPjxyEp6Y/HEyfg1Kk/L2lpjuBzLIasLPcncmCAneAgQ0iQneBAOyFBhuCgbOuCjPO5ISjw70uA2AkQQwCGADEIBmO3Y7cZ7HbjeMwyZGZCRgZkZMLFzAAyMoVzGcGkZYRwzhZKmr0ohgBAgBLO5Q9lSaFSSCqVws8QUzGFqlX2cEOdUG6ILc0NzStTNqo4IjXd/vNRhZune/JNgH3GmAMAIjIP6AhoyHsxYyAlBfbscSy//goHD/6xJCfn/LlSoemUCj5HqcCzlOI01ewnKZ51imK2MxTLPE0x+1mKkk4oFwjCRjCZlx8DsCM4uvGXHg2CnQCyCCSLQOwEYCMIG0FkEkwmwdgIIsMeQmZGMJkZwWQQQiZ/f7z0PItAbASR7vysjaDL7WRfsr8SDAHYCSHDsYiN8MAsQgKzCAvJJLykjbCidsLDDGHhULpMIGUqBBFRKZQy14cREVWC8nXLU6RyWRDtkauC5emQrwwczvY6Ebgl+wYi0g/oB3D99dd7uBz1V5mZsHMnbNkCW7fCrl2OYE9N/WObkKAsqpY4SVTIUTqaA1QttpvK53+hAse5jiQqcJxypBCcYYfwCChT5o+lVCkoXjzbUh7CwqBIEQgN/WO5NPZyaTwmKMgxPiPy98WYPxYAu93xPPtjTuv+NBZkBzLAXPxjLEjE8RgQACEhjnouPQYHQ7FiULSoozalfITlf1uNMfFAPEBsbKy5xubKRUlJsH49fPUVfP89bN/uGCMHKF3sAnVLJfJAkQRqF99MrbPfUYs9VLYdIeB0AFStCtWqOZYqN0Kl5lC5MlSqBBUrQkQEOgGKUt7F0yF/BKiS7XWkc50qICdOwMqVsG6dY0lIcKwPD80ktuxBBpX5jtiUZcRmfkP0+V8RCYM6daBuXajTBuoMhZtugipVtAerlA/y9L/a74HqIhKNI9y7Az083GahZowjyD//HD77DL7+2jFSUbxoJs3L7eHRssto8ftCGlz4kaDUEGjUCLo1gSYvQ+PGEB2tvXGl/IhHQ94YYxORgcByHKdQzjTG7PJkm4XVzp3wwQewYAHs3+9Y1+C6Ywwv/wXtk6Zzc/pWgk4Wg+bN4c6HoMXbEBOjvXOl/JzH/4UbY5YCSz3dTmF06BDMmwdz58KOHRAYYLi78h6GlppD+1PvU+X4EYiLgyc7wt2THL12DXWlChX9F+9jbDbHMMyUKbBqlWNd00oHeavUO3Q9FU/51HPQqhV0GAXt2kF5nZxKqcJMQ95HHD0K77wD8fFw5AhUKXWGF8vN5uGU16mWfBjatIGeE6FDB8epfkophYa819u5E159FebNc1wt2rrSDiYHj6LdqU8IatoERg2DLl2gbFmrS1VKeSENeS+1eTO88gosWQJhRTIZVGY+TyS/wI1njsNjvWDANqhf3+oylVJeTkPey6xfDy++CGvXQkSxC7xQcipPnn6JiHKV4IWh8PDDjitHlVIqFzTkvcSOHfDss7B0KVQsfpbXwsbR79zrhN8aA8/NdnyJ6i3TLSqlfIZe9WKx336Dvn0hJsbw9ZoLjC36AvvPlufpZt8Tvn4pbNwI7dtrwCul8kV78hY5fx5Gj4bXXzeQZWdo+HSeOzuciFaN4eWNcPPNVpeolPIDGvIW+OwzePJJx8VMD5f8nNGnB1K1RnkYuwjuusvq8pRSfkRDvgAdOgSDBjnOmKld/DfW8zDNyx2D6eMdp0HqkIxSys10TL4AZGXBa69B7dqGVcsyGBs0gh9t9Wj+yr2OCdy7dtWAV0p5hPbkPezAAejTx/H96X1ha5mU+QhVO98Mb+wAvUmKUsrDNOQ9xBjHNARPPWUIzLzALPrTq8I3yFvToG1bq8tTShUSOlzjAUlJcN990K8f3GLfxI6MmvQeUgbZ8ZMGvFKqQGlP3s1WrYIePQxnT9p4k38zsMLnBMx63zGPu1JKFTDtybuJ3e44771VK0PZMwfYaoth0IAMAnZs14BXSllGe/JukJoKvXrBsmXQI3ABbxd/lvDFUx3T/yqllIU05F30/ffQ5QE7SUeymMIgBjTdhcz/GipVsro0pZTS4RpXzJsHzW63Q1ISG+238s/nSiNr12jAK6W8hoZ8PhgDL70EDz0EjbM2szXsDhp/8SK8/LLeQ1Up5VU0kfLo4kV4/HHD++8LvXif6bUnUuTz1Xphk1LKK7nUkxeRcSKSICI/icgnIlIq23vPicg+EflZRFq7XKkX+P13uOduO++/L7zESGZ1+Igi36zVgFdKeS1Xh2tWAnWNMfWBX4DnAESkNtAdqAO0AaaISKCLbVnq4EFoeksW331jYx7dGPlsFvLJxxAebnVpSil1RS6FvDFmhTHG5ny5GYh0Pu8IzDPGXDTG/ArsA5q40paVEhLg9qY2fj+YxpqAe+g2u73jBqwB+pWGUsq7uTOl/gEscz6vDBzO9l6ic93fiEg/EdkiIltSUlLcWI57bNsGzW61YUs+wfqwdty6dozjpHillPIB1/ziVURWAdfl8NYIY8ynzm1GADZgbl4LMMbEA/EAsbGxJq+f96SNG6FdGxul04+yqkx3blwTD3XrWl2WUkrl2jVD3hhz99XeF5G+QHugpTHmUkgfAapk2yzSuc5nfPkl3N8pi+sz9rPy+seosnYuREdbXZZSSuWJq2fXtAGGAR2MMeezvbUE6C4iRUQkGqgOfOdKWwVp2TLo0D6Lmhd/4qs6/6TKt4s04JVSPsnV8+TfAooAK8VxZ6PNxpgBxphdIrIA2I1jGOdfxpgsF9sqEKtWQeeOWdTL+pHVt/4fpZZ+AiVLWl2WUkrli0shb4y58SrvjQHGuLL/grZhA3RoZ+OmzN2saD6GUssWQrFiVpellFL5ple8Om3aBO1aZ1I1Yx+rbnuBMsvmaMArpXyehjywZQu0aZnBdRcOsTpuJOWXv68Br5TyC4X+ap7du6FViwwi0o+wpslzVFo1G8LCrC5LKaXcolD35I8ehbYt0gk5d5rVsf+hyur3NOCVUn6l0Ib8mTPQ9o5znEixs77mUKqtfkfnoVFK+Z1CGfIZGXB/63Ps3hfC5xX70Wj9BChRwuqylFLK7QpdyBsDjz50jtWbw3i3xGBab/wvlC9vdVlKKeURhe6L1+FPpTPn4zBeCnmJvusfgWrVrC5JKaU8plCF/LtvZ/Dqm0XpF/AOI5bdDg0aWF2SUkp5VKEJ+U3fGAY8IdzNSibPLYXcdafVJSmllMcVipA/cgTub51GpP035o/cSVD3LlaXpJRSBcLvQz49HTrdcYK0NFjSfjoRo4ZYXZJSShUYvz67xhjo1+UEW/ZHsPimYdRZOAocs2UqpVSh4Nc9+ddfOMOcpRGMKjGejhuGQmio1SUppVSB8tuQX70sg2GjwugS+DEj190NFSpYXZJSShU4vxyuSUqCng+kU4MjvDs7CGnYwOqSlFLKEn7Xk8/Kgh53HuNMejALH1tBeI8OVpeklFKW8buQf2ngcdYmVGRKzUnUmTrQ6nKUUspSfhXyq5ekMWpaOfoUnU/fdX0hyC9Ho5RSKtf8JuSTjhl6PphJTRKYvDhSv2hVSincFPIiMlREjIiUdb4WEZkoIvtE5CcRaeSOdq4kKwt6ND/MmYtFWPjv7whrdZsnm1NKKZ/hcsiLSBWgFfBbttVtgerOpR8w1dV2rmbm8H2s3Xc9UxrNoM7/+niyKaWU8inu6MlPAIYBJtu6jsBs47AZKCUiFd3QVo76dkljfv0x9F3TW69oVUqpbFz6ZlJEOgJHjDHb5c/hWhk4nO11onPdMVfau5Lgxg14cHsDT+xaKaV82jVDXkRWAdfl8NYIYDiOoZp8E5F+OIZ0uP76613ZlVJKqb+4ZsgbY+7Oab2I1AOigUu9+Ehgm4g0AY4AVbJtHulcl9P+44F4gNjYWJPTNkoppfIn32PyxpgdxpjyxpgoY0wUjiGZRsaYJGAJ0Nt5lk0ccNoY45GhGqWUUlfmqauFlgL3AvuA88AjHmpHKaXUVbgt5J29+UvPDfAvd+1bKaVU/vjNFa9KKaX+TkNeKaX8mIa8Ukr5MXEMn3sHEUkBDuXz42WB391YjpX0WLyTvxyLvxwH6LFcUtUYUy6nN7wq5F0hIluMMbFW1+EOeizeyV+OxV+OA/RYckOHa5RSyo9pyCullB/zp5CPt7oAN9Jj8U7+ciz+chygx3JNfjMmr5RS6u/8qSevlFLqL/wq5EXkJeftBn8UkRUiUsnqmvJLRMaJSILzeD4RkVJW15RfItJVRHaJiF1EfO5MCBFpIyI/O29n+azV9eSXiMwUkWQR2Wl1La4SkSoislZEdjv/bg22uqb8EJFQEflORLY7j+NFt7fhT8M1IlLCGHPG+XwQUNsYM8DisvJFRFoBa4wxNhEZC2CM+Y/FZeWLiNQC7MDbwL+NMVssLinXRCQQ+AW4B8dMq98DDxljdltaWD6ISHMgDcdd2+paXY8rnHeaq2iM2SYixYGtQCdf+3MRxzztYcaYNBEJBjYCg5131HMLv+rJXwp4pzD+fEtCn2KMWWGMsTlfbsYxJ79PMsbsMcb8bHUd+dQE2GeMOWCMyQDm4bi9pc8xxmwATlhdhzsYY44ZY7Y5n58F9uC4+5xPcd4iNc35Mti5uDW3/CrkAURkjIgcBnoC/2d1PW7yD2CZ1UUUUle6laXyEiISBTQEvrW4lHwRkUAR+RFIBlYaY9x6HD4X8iKySkR25rB0BDDGjDDGVAHmAgOtrfbqrnUszm1GADYcx+O1cnMsSrmbiIQDHwFD/vKbvM8wxmQZYxrg+G29iYi4dSjNUzcN8Zgr3Y4wB3Nx3LzkeQ+W45JrHYuI9AXaAy2Nl395koc/F1+T61tZqoLlHMP+CJhrjPnY6npcZYw5JSJrgTaA274c97me/NWISPVsLzsCCVbV4ioRaQMMAzoYY85bXU8h9j1QXUSiRSQE6I7j9pbKQs4vLGcAe4wxr1tdT36JSLlLZ86JSFEcX/C7Nbf87eyaj4AaOM7kOAQMMMb4ZK9LRPYBRYBU56rNPnymUGdgElAOOAX8aIxpbWlReSAi9wJvAIHATGPMGGsryh8R+RBogWO2w+PA88aYGZYWlU8icjvwFbADx793gOHGmKXWVZV3IlIfmIXj71YAsMAYM8qtbfhTyCullPozvxquUUop9Wca8kop5cc05JVSyo9pyCullB/TkFdKKT+mIa+UUn5MQ14ppfyYhrxSSvmx/wd/D+ouaWFrvwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# 画出更新之后的结果\n", - "y_pred = multi_linear(x_train)\n", - "\n", - "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", - "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可以看到,经过 100 次更新之后,可以看到拟合的线和真实的线已经完全重合了" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, - "source": [ - "## 4. 练习题\n", - "\n", - "上面的例子是一个三次的多项式,尝试使用二次的多项式去拟合它,看看最后能做到多好\n", - "\n", - "**提示:参数 `w = torch.randn(2, 1)`,同时重新构建 x 数据集**" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/1_NN/backup/6-nn_summary.ipynb b/6_pytorch/1_NN/backup/6-nn_summary.ipynb deleted file mode 100644 index 51200bf..0000000 --- a/6_pytorch/1_NN/backup/6-nn_summary.ipynb +++ /dev/null @@ -1,1955 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 第四章 神经网络工具箱nn\n", - "上一章中提到,使用autograd可实现深度学习模型,但其抽象程度较低,如果用其来实现深度学习模型,则需要编写的代码量极大。在这种情况下,torch.nn应运而生,其是专门为深度学习而设计的模块。torch.nn的核心数据结构是`Module`,它是一个抽象概念,既可以表示神经网络中的某个层(layer),也可以表示一个包含很多层的神经网络。在实际使用中,最常见的做法是继承`nn.Module`,撰写自己的网络/层。下面先来看看如何用nn.Module实现自己的全连接层。全连接层,又名仿射层,输出$\\textbf{y}$和输入$\\textbf{x}$满足$\\textbf{y=Wx+b}$,$\\textbf{W}$和$\\textbf{b}$是可学习的参数。" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch as t\n", - "from torch import nn\n", - "from torch.autograd import Variable as V" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "class Linear(nn.Module): # 继承nn.Module\n", - " def __init__(self, in_features, out_features):\n", - " super(Linear, self).__init__() # 等价于nn.Module.__init__(self)\n", - " self.w = nn.Parameter(t.randn(in_features, out_features))\n", - " self.b = nn.Parameter(t.randn(out_features))\n", - " \n", - " def forward(self, x):\n", - " x = x.mm(self.w) # x.@(self.w)\n", - " return x + self.b.expand_as(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 0.6614 2.4618 1.6848\n", - " 1.7110 2.8197 -1.7891\n", - "[torch.FloatTensor of size 2x3]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "layer = Linear(4,3)\n", - "input = V(t.randn(2,4))\n", - "output = layer(input)\n", - "output" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w Parameter containing:\n", - " 0.7730 0.1062 -1.4568\n", - "-0.0182 0.3505 1.9311\n", - "-0.6398 -1.5122 0.5403\n", - " 0.1200 -0.3439 0.3741\n", - "[torch.FloatTensor of size 4x3]\n", - "\n", - "b Parameter containing:\n", - " 0.4206\n", - " 1.5090\n", - " 1.1140\n", - "[torch.FloatTensor of size 3]\n", - "\n" - ] - } - ], - "source": [ - "for name, parameter in layer.named_parameters():\n", - " print(name, parameter) # w and b " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可见,全连接层的实现非常简单,其代码量不超过10行,但需注意以下几点:\n", - "- 自定义层`Linear`必须继承`nn.Module`,并且在其构造函数中需调用`nn.Module`的构造函数,即`super(Linear, self).__init__()` 或`nn.Module.__init__(self)`,推荐使用第一种用法,尽管第二种写法更直观。\n", - "- 在构造函数`__init__`中必须自己定义可学习的参数,并封装成`Parameter`,如在本例中我们把`w`和`b`封装成`parameter`。`parameter`是一种特殊的`Variable`,但其默认需要求导(requires_grad = True),感兴趣的读者可以通过`nn.Parameter??`,查看`Parameter`类的源代码。\n", - "- `forward`函数实现前向传播过程,其输入可以是一个或多个variable,对x的任何操作也必须是variable支持的操作。\n", - "- 无需写反向传播函数,因其前向传播都是对variable进行操作,nn.Module能够利用autograd自动实现反向传播,这点比Function简单许多。\n", - "- 使用时,直观上可将layer看成数学概念中的函数,调用layer(input)即可得到input对应的结果。它等价于`layers.__call__(input)`,在`__call__`函数中,主要调用的是 `layer.forward(x)`,另外还对钩子做了一些处理。所以在实际使用中应尽量使用`layer(x)`而不是使用`layer.forward(x)`,关于钩子技术将在下文讲解。\n", - "- `Module`中的可学习参数可以通过`named_parameters()`或者`parameters()`返回迭代器,前者会给每个parameter都附上名字,使其更具有辨识度。\n", - "\n", - "可见利用Module实现的全连接层,比利用`Function`实现的更为简单,因其不再需要写反向传播函数。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Module能够自动检测到自己的`Parameter`,并将其作为学习参数。除了`parameter`之外,Module还包含子`Module`,主Module能够递归查找子`Module`中的`parameter`。下面再来看看稍微复杂一点的网络,多层感知机。\n", - "\n", - "多层感知机的网络结构如图4-1所示,它由两个全连接层组成,采用$sigmoid$函数作为激活函数,图中没有画出。\n", - "![图4-1;多层感知机](imgs/multi_perceptron.png)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class Perceptron(nn.Module):\n", - " def __init__(self, in_features, hidden_features, out_features):\n", - " nn.Module.__init__(self)\n", - " self.layer1 = Linear(in_features, hidden_features) # 此处的Linear是前面自定义的全连接层\n", - " self.layer2 = Linear(hidden_features, out_features)\n", - " def forward(self,x):\n", - " x = self.layer1(x)\n", - " x = t.sigmoid(x)\n", - " return self.layer2(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "layer1.w torch.Size([3, 4])\n", - "layer1.b torch.Size([4])\n", - "layer2.w torch.Size([4, 1])\n", - "layer2.b torch.Size([1])\n" - ] - } - ], - "source": [ - "perceptron = Perceptron(3,4,1)\n", - "for name, param in perceptron.named_parameters():\n", - " print(name, param.size())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可见,即使是稍复杂的多层感知机,其实现依旧很简单。这里新增两个知识点:\n", - "\n", - "- 构造函数`__init__`中,可利用前面自定义的Linear层(module),作为当前module对象的一个子module,它的可学习参数,也会成为当前module的可学习参数。\n", - "- 在前向传播函数中,我们有意识地将输出变量都命名成`x`,是为了能让Python回收一些中间层的输出,从而节省内存。但并不是所有都会被回收,有些variable虽然名字被覆盖,但其在反向传播仍需要用到,此时Python的内存回收模块将通过检查引用计数,不会回收这一部分内存。\n", - "\n", - "module中parameter的命名规范:\n", - "- 对于类似`self.param_name = nn.Parameter(t.randn(3, 4))`,命名为`param_name`\n", - "- 对于子Module中的parameter,会其名字之前加上当前Module的名字。如对于`self.sub_module = SubModel()`,SubModel中有个parameter的名字叫做param_name,那么二者拼接而成的parameter name 就是`sub_module.param_name`。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "为方便用户使用,PyTorch实现了神经网络中绝大多数的layer,这些layer都继承于nn.Module,封装了可学习参数`parameter`,并实现了forward函数,且很多都专门针对GPU运算进行了CuDNN优化,其速度和性能都十分优异。本书不准备对nn.Module中的所有层进行详细介绍,具体内容读者可参照官方文档[^1]或在IPython/Jupyter中使用nn.layer?来查看。阅读文档时应主要关注以下几点:\n", - "\n", - "- 构造函数的参数,如nn.Linear(in_features, out_features, bias),需关注这三个参数的作用。\n", - "- 属性,可学习参数,子module。如nn.Linear中有`weight`和`bias`两个可学习参数,不包含子module。\n", - "- 输入输出的形状,如nn.linear的输入形状是(N, input_features),输出为(N,output_features),N是batch_size。\n", - "\n", - "这些自定义layer对输入形状都有假设:输入的不是单个数据,而是一个batch。若想输入一个数据,则必须调用`unsqueeze(0)`函数将数据伪装成batch_size=1的batch\n", - "\n", - "[^1]: http://pytorch.org/docs/nn.html" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "下面将从应用层面出发,对一些常用的layer做简单介绍,更详细的用法请查看文档,这里只作概览参考。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.1 常用神经网络层\n", - "#### 4.1.1 图像相关层\n", - "\n", - "图像相关层主要包括卷积层(Conv)、池化层(Pool)等,这些层在实际使用中可分为一维(1D)、二维(2D)、三维(3D),池化方式又分为平均池化(AvgPool)、最大值池化(MaxPool)、自适应池化(AdaptiveAvgPool)等。而卷积层除了常用的前向卷积之外,还有逆卷积(TransposeConv)。下面举例说明一些基础的使用。" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAAAAACIM/FCAABaCUlEQVR4nGX9+dMlWZIdhh13vzci\n3vJtuWfWXl29THfPTnAGhiFpBAmSv0lmNOnvk2SizARBIo0gODAAJAAS2wCD6Zneqpfaq3Jfvu1t\nEfe6H/1w42uMTFXWbVkvX+b3XkRcX46fc1z+b4Q6QysI0WzakbHfXr68KF0yMVdQAddsqoYUWYhE\nRyiQAiIAQBpDRQgKPCAQIagC8UhpcgNcBK4UEuIenHzsVkfDkLLGASIK0QqCsFANk8pbby9TwuGL\nf/vSug4JISpVQRWRCJiGsAT7949fKwAQIPDXfkWQ7T8J3vwHfvMWEIC0/5ebl29ex///2+e/6T/8\nDv/a/37zJs7vkr/+OQAAqn/tvQBF8Nd/hIB1TEGQDgkh1BUKIFgLyIpkTgBKAcIgAkcIEAAURABC\nyvxXkAChUBIQcQqBgEQVCQjIaB9Iov0hcS+dawgsEEpRYbtCDDAQu11vlHxy93UNigsEFIgI2s8j\nBIYsX5w/TlUgAKABhKi2T+UAQ8iqFEEAEISAImRA5D/cthABIRQhQkCGQoKCAOevhri5T4RDKUFo\nBcTqIWUzV1VQDAQMQSGUAEP2rwdT1eP3Xj33mlTIUKGAFEAgAkhO/vrTV1UFAapCBAJTIUmG0wkg\n2oNPBCEASaoACIoCArKCQZFo34sUiJAi858QFSCCYgKIqApGUFQAFRErh1I9ov2WJlWBiChIkkRc\nXhanDPffG9w9CAgDEYEAKBKgZXn56+dFVKAkBIRIyiZgeIVQRYQmqiqAMDhfWSJIantIBcIQBaEQ\nhqpAhRARaQ+JqoooQCoEJAUhEAUSwKScxogIWDJTtZRERNohCAIyvdk5oet3HygjIoJQiCAEKgKo\n2fj4kyuKKtGeA4CaVILe3h9gUAJkQNoTQpkvsghBiAg1iQhFlYQIBRRRUUFA2/lUEQjkJnpQlxDQ\nlBAzUR+niHAkE1FARNQAUMEAweuLKYh8+8Mj92ihQBRQAVQt5Tx++fHLHRhKiEBUINrCJ52E5qwQ\nDWE7QSKYQ14EICJAu3ZBRnsMCBAipLPFGWmfvT21FqCogqRIkBBVcUjUqQQBigEqUEBM20cSitRX\nm0rK4uEjiXAQAm1PqahatsOXn1yZKaAIQCAQAiKMYIQIIICAkhntngmCBCER83GIYDv0ZASChEQ4\nJeik0ANksN2MCECiBohgiDAoCYBoTI52oNrfrHQSaHEtgM2bQw3Y6QenQhJRwABpydTMDp9/uhGT\nUNP56EFUVSFkBCAa1YMgnKQoqHBSBC6miBa0IoI3QSwAUQToU7h7jSBB0quThLTDSYIS7aaSEoQo\n3NniJiJICKM9BO2GxpurQkp//0NzRwvQZLiHmOw/e3wASsBMoWCLzikJAA9AAK8eZY7IjDnACkXb\nPQgyKCoCMjxACsRMIiI85ntDtsshIsJ2d1r+NJIUUxFCNcIBEgSjvYkqIqIiIpDDm9EJO3r/bkS4\npnbIIyDYf/rLqygeFGPSFn2gkuzmxotSVLUFwzl9U4DQVnmIMCDzmTBSAIRoey0oIolAQKESyrl6\nkBbABaAoolIUmO+nQCUACFtGa185hJB4c9qbSnfno9cTWUxBCE1N9r/+dGuSYCLepRZEAbU5QLas\nrxBJYJjOFQp1zoqirUAIbXEA0p5sDSGgFSACQgEDMCGpjJv6AzpnUAABgiFwd4EAIUB7Simcb785\nds8GQdLVB29+5SIBhShEdfzq823SENMIbfkN0vJFy2gAI6IlT6FHUKL9RrSf0W6gMGJO9EHOtWB7\nMghQwxlsDz0EojIXTMKbjMoQiIUzIhiBdhnbAygkW+AWubp0iNjJ+0e1eDtRZunw+S+uwNRlBURr\ngsCkRcb27N98n6gwgiSEogBBQTCRAmr7oQFqtGexxTSKUERDEN4OA6gAJSiABBiqBIVBmjjpAUSE\ntOKtXY+5CJprBS0vT/pE7e5/cD6pBKlitv38sw3EoHARuqu12K4SXiNC2leSaDdFACAgDM7npd2f\nCAZFWr4QkgQEAYEoAZZaSCLoDIItw7ZbT7baj6SIaPt1eJAQIEA6naKtGiFJXr0ZnbSjDx4wangN\nlf0Xn27UxDpTDUaVRChFTd0DEiK8qaqkPecQtNdCqIAowFCJ+Tiy5e24KcYVrgiJEFKiVSot/yBM\nGSmEFIlAyx4iDEYIRL0VrAAp0WJjC+/wV2e9gfnWh69GA9T08PXnO9DM6A5G0BMogFqUSgBJNObQ\nB4gEdC55oRBtJWErOKRldrjonLuYEAxRUtOyo1efSq0a1u4a2pkDtH0vkQAixIPt47YLOPcbrqQz\nHAREtq9WJiqLd77+nBS1w+NPNhBTUyVDKy2SSECVpdYATFopI2I6F82iTkBETUIFc5UFaQ9Jaw6g\n7QorxXQA0nJtFkljvD7f7ScxFaWIQOe7Q8icVOceLQBo6By6QIG3Kg8gBP76tDNFPvvOm71qmr7+\nfAOltkjlAVASCKgxanGoyZxVWomIVsGzPdet6lORAEWp7Y2KokYKydT3SOpY9pROKV0Mx+9Mb15f\nHmqiRAKgflO9EaL1plEUkC4xh/NoPVOQLiSFwOHFKonp8u0PfqHZv/nsShGqBnHCXRRIDDGF1zqG\nJKAVhwRb9lYgKFQqaCqCCFOQKioEFRYUOgK0vuusloRhKa6SIg7FOhnefnvz+tkFVJ0a0rpnoLLl\nltBoBTklRLSFrQDAYLSoRSH9/Kg3RT799vl5PPvsGgKFIAoRSBKBRLEsXuo01pzmvksUgJNEKPwm\nMRtBaItNUGG7cKSBUMqyL2S4ZB9FtDA30IKWz87e+frLrdFAgQTnKspl7lhaVGknRFpean2PiN9E\nm93Tky4Rw8Nv/9XjzzZIKqoGMKp1EYAmiKqUMu33kZKqANpCKV0NrAoSkRHhEipsBYU42i8D0lmK\nw9hZ1XAMQj8kFXdxTYwkYKTVd+79+kkZWgVjcZNySReZQzd4UxDEb5CP1hA1wOHy2bpT6vr9X3x6\nTiEttYZXTYKhksSSllIOm11aa2sx5roIsLDWH0soENoKKhjbAQId+c64LIf9gXWXEDJEq3yM2Mug\nVU0pYPD20dkno0Os5e5WEUc4lQFI6x6DJhQp0lrOIERDQBWhv7ydRSUdl1eEaUML3GFamQBLahq1\nTturwwnQomvrv4Q+KZOEUiWgDCVAoc7PBsG82O9wMTpFQhkl+5gH8yyTiiZFb3Sq0AT9d+786rFr\nnsEezFk05hg4X3cXUVUndG5HfY53Yvunq5yFL35akwDJoIhCM3cHsqaUpdQ6Xl/GcasSMN9ddw1t\nBwJUaVUtAzegEhly7+p8ml8zRFBdorikQwxijN5ImDrEGLj1O8efbDwlZ6sFWpcSAcrcykDmGjkk\nRMG5cYNRlDx/eST+5h/+UimqSUI5hVp4CDJ3ydRLGXdXm67lutbOesNvFNB2C1wJhlChEtCGAuX8\npFYjVCqsqovDqqiOOaUsIuFqodEAJ/H+e7d//rgmzD11zGEDpioqTlICgqSt62kdAzHHR8SLO931\nP/4nbgIoQlGpWWrV1PFqlzRqrfuL87JAq1KoJIMRhBJJJVoxKCLqgKDMiI0k+7V0ZAggakYVTcW1\nZIUFFeHRVYE2wA4i6e2TH/9qZ5rUWsMqRCuQBYLwIFwoYhGtqWrtHKAUEavn+//tH22SQJSu4d6L\nBJMM3eU2JUSth93VJImikIZIYUYTG9ITJKgUhgjbhaIicFJHVS+SkHNI0cTkNUlWTiqTmqUM6bSh\nWySEXP9h/vHEnLKATp1bAxFRghI1AiGdzhVqq9YgDXFQ++bnf3athJmGBCfN7i6pW2633W8l1lr2\nFxtR2M3jL2x1okCkIZk3OKaBou1YiqS7nwtLaAcg1L0jiQFOqhahqQFi5FzYExypyz9+eP7rJ1kM\nJF09TGBJRaUS8LEqzazV1d6O+VwKqPov/+01SCZT8RLIwRDtht1FfPdvpVLKYXsR2dHqTiVdVCgi\nghDzdjfa4yMEQwFhGIbXW1ZIMndkKZ1UmNINNLdgTknVqlIVLgwPTcvc5bv43n//mIPQYd5QhohM\nwCB1PKiryU0jJPNFCxNTef7xtRCiFowa2isrJC3lOt76vQday2FzeQgxzEjlHNRDSSgrKIh2cxCg\nmKopQKR44w5NiaGGqvQaUVyqdzXUCFVDtEEDxrFgfet0mViLvvdHsj1UB73hDsG5bTCdrs8vdqOL\nMRytSQPmfHD+k6cANWUlCKYUE0OXq5CHv/8wJ6+H66tqN1D9PO2IANsxTU4ooQ15allASBPdbkM1\nm0Qyp6ubWTvWXtW0dxc3qFX4NOX1ojdHFSol/+CzPz/AIC5gbbARQYGYHPY765KqhgBUNuRDIOnw\nyecjVUUNEh6WfILYshvH299/tytpPFyf75MREvxrBwSWqjAULpTkDCPlBginKqRuqopKSEJI0Ltk\nMCBCICYWEgyYie/Rna671guLllA9+s+efxGdKcRbdarqrYmT8iYNq2RZ2Q5IQF1ATfHNzzcgNJu4\n1NJlFtpC+7w9+uC9vlDHzfkGppjjA+aUpzOewoAqQcAYIXMzAZhvA1ANJq/qNVSFKh6wjKQSwRhZ\n63R14WePbvdRSuuQzadDvfd37h7GGnTMnz9atSjJyvWbq32hWMOctRXLKuc/v1Cz1CdSImzQUpG7\nIfHog3fXHPe6u77yJC3I0oPzICPIEBFr5VzMwJ9AVBUIxG6CKUJzDU6TdkkLcqhZUjWrEyqN4S+u\nj98+TWWsLR+5V1Mgf/tv9dN4KO7z1WnVqplm7C+utvtKaWgIRFTN0u7n36xOk2ZziFf0UpnMQq1/\n+GilnDRtLg+SGRIxt8tz+qAIxEQj5sbWbW7IVGByKEgCEXUGXbUDu05pELEpaBopOa+ms0c9a9EO\nHgGlJaPQLf/h13/mKdRgrcckRBpGVMftfhg0m5M3oJrhya8OVcLEwVqi62s1YaS8uPVoLTEtjtLl\nVaRWj8y1oAYJFWu1tYFo9S4BKjwkVK3QIAgz0hkpKxXKAhO1MuXIZtTry/7BnY7wJO7uGpqDJQRe\n0/I/nX61DctSeYNnCcBwmG8O42Rq7dAKREzOf3VdoTlrIKaSu3oQkc4837l/mmuxdUqXxQRQUuEz\n4kFECIGawxquEqpQUCVEVKlp6yatUXUzQpJCkItWiyhGgzO9vrz7aKGFfVeICs2sJRxJFR7prf/T\nz//xkwJLNwNPp5hABbx6YZ2aSoN0RMzGT76sQusUoEcedB+dQdSO7h2rhBxnSRNMWks+V9Jt3tQq\nkdB5oCekCRkNvuvKFNppoJVEkC45UoLRJCbWRGG88g/uhTOzaERSLZMkms4le9TV31j/6S8lp4gZ\nPxWNCDGbJh2GLiVtMUBU48nHW9HUAB1Il8fIAEJuPzpWOtYdkMRUqCGqDdSVGaOFAhTS22w4rE1R\nFBFHunHXZBKCQEC6HJY7cTE6QgRap3H9/plZwQFKkyroUkUOiFRVhnjJv7X/5nzVBaKh+WRtZVns\nz4ehK52AiATRy4/fqIqZOFN4Z9Okpg49fe9eioplD4S2zkYhqq0NaXCjBJUyA6UMb9h4IKLKO9w6\n5kEOgmbhZknrFGIm4bXudt1772QKUaVGHErkzjyj3kR3hirrD/8Ypbj/ZpIkMFOqlovXF9vD1E6P\n6PTp11RVNTcLF8OExHCs334wsNS+F0BSan0n6WCbezagVMGEgCkDOo9eVIKIo9NPvGsoglKSmTpE\nUzmIkVE5ufDew9sBnwitWqDZQNFI+17aNIpAEu3+6LMvpuJkGwgEW7GnPp73ncQiIQzg818cBKpG\nMAKdVsnZD3723tvLqL4eLKiSFBQoSa9t3oM2wEEQdLsZwkrSVsY73sZWQHUNE0sEkTTpoYgi1CfW\n3fL92wMWY5lyKpMtxFyYVAuyKw8iSQkEVd7/r/+vu9pOSOvhIKYuGptngltvn8rT3fHuzc9ew6Ap\nVULDxGuyKCXdf/fYJ18uDaRLEsQNeDqXUa3mMhHCJMRAhpqydYp+/N43UyJFmVhURc3E3DUHIWXH\nuj97535H38NQNtINqZXD7eGsIEyV1QGk9Nvf+zde0XgX7We2Ybpf3f+d3z+J6QjLZz/+IsTC1KG5\nqriLYfThwbfOpNR+mRokmRrHQR2MeZbXarhgNbdiEgFoeKiSDi8PF2+qQCHZoSDV1FIx2yPpbusi\nb7+j2h32sRzKPvUWruJG9ySVktFXV6qGUop0P/yL4pznmARCQbCGfuv//IeymYaTTt67szMJmHry\nERCRnKZJTj56kErkVd/SDZOEQNFg8/hrAIcXBuIGdhQq6BB4fv9QrJpYuIhRqUQQ4y51st/Xw+m7\nR4tUD3vYYZ+OELUT1xDSpjazLMaaQqQKJu2+8+6zUhviJKCER3jU/g//2+9ebxani6Ta//6fXoNJ\nnSJO9TBx6vDWo4W7LjudK+fkgoZNa8OnlUG2uY57sjbJFBCgGpxHt6/b7C2UoLpl4TQVP+p8ci93\n3rmV0mHX91pK1qqKCBd1xKS50wiNmmLKACNhsrv/yd+fAmINpCei1MqTv/Pf3Hl6uHukUsX6j979\niTKq6hhMqDC4929/cMTCRQdHgwpTYiMDtAQyl8BAhJg4tY2foPMcQaa7y1ebEko3G2GeBqtThBzr\nIcZxeuvdLnNbGT6lEz+klVUUKVkOklPipFogRUELS2Cx/Puffh4VFoogGeN+4jv/h/9UPud7C5aK\nSLZ+96eURp2QKGoqwPH796xE1zeqDyhM0igMiQ2CafgD22A3KW0eSavAlEK9E8WLAWL7cRVdzzq6\nLBbjPte9Pnprmco+8tK9Q2VOQhtFbMO81lo1xiG5ppkzAdGDr/7mqwp6GyWUcZrsB//t71w/Pnq4\niOoSPtR02s6tOKVGEoma3no0lEi9zXNMNC6JWEA6qFBIhVDVREzF5jEPADWAwoln7qIGjWtZWjZ3\ncem7QO87PHwk3FfXNPWDjHXofVQYp6lbaMAmGotqoCsOZIEwCt79k09qhJJwH7ebW3/0X90/P79z\nx2qU0aRLAFw1aymukCRSRrn73nG4DBkzKtLQMMwUEd7MXWdOj6haSyk2D/sFGI7j+mAh7nJkIGNK\nXRiYpjfy/lq6cw6r7XbZo2SyCqrGdFgsMEmGuauUzjav0pGKEK5w5N+9+7GnAMLrYfroP/kenvrD\nlVTUKjATqaOoUkRBt6zcH87evyeT9501yI3i7Yu083FD5hPY/Ep7tR2fgEKEXK8O1y5TktwFXYex\nS9JxX20j75zZdueQq7zgoTOVXenJuuNKwzBSDbAwTud+xxrQVEPI9E58AgnGNC7/5IfH1/t0Zwh3\nOFVNheVCESGmJSwJYrJH7w6M1AJWowwxklK0Eb9MwaC06ksMAMJEQFQlAsrwadG/PGguYlZCpPeU\n1GIKbNJbZxGlLPTVyRKSIkJyrYZrnnZl6ozOEOTE8Vm6o25wIQln1HTv1QYRkT54+3R8XRfrVKOC\nVEkGcP9czGpVCYeRRW9/cCJVujw3AA0KTTPECkk21/8250VADCHaJhguAbif8tL6Qw8cxPMQUweM\nhwRf3urC92l9qad2GJJXj64f9lcx5MqujNGxLrJU22yOejBmWF9kUsnd2TXp63fuxHnVkwElglCF\nqYjg8lxMLG4g1rT86FHnrjZ3HnNySDIPVAmBWCPjhYkEAEObJrSiUYIVd0JNoGCuWaLkLkYZ5LC8\np9y74UKOTIJenB3c95vTfmIANe+TFj3sMxdWgJQkxUGhGkWnWFlZ3r7THUb261RrCKDKlgPk1SaJ\niLlXzYoun721cpfOGotCxQMgUptCzzQUbZBxS/WNeKQgIkRBF5Z827WzNNE0CZCNU+6x7W6Z+rbX\ni3ysZYHzIWUiXVY500kRBRodBx6u+myyS4PR3GsxK8Kaue/uHd5ebA/W5SwlAoAJGO20vtgJG/6e\nCMurR7etRtcwNAJs2E6KOaVry9cqNyEMjCmbwxAEQmB0Pz12tWwGqdFLMmEaYpPvGWvJ5TovXZcH\nWE1d2mzHfs2oKVGqdFxy8rUQuqRI1Ni3wZXUmkTumO6Y88AKEmECIKZOSKlfHjJcNQIilvo7D1cM\ny/obeqEIqZJSCVBCwAjMI90ZEmpko5msBaiDy36TrFumulehhPcmHOVOln0sy3axiFSvgHWne62H\nRZbCzsckElO/Zc6ufuiTVmJyL51FAfZ1fZTr9dClPscYM/UJ5DVyCOPqkxCA7jRNNW49OJWKZDpX\n6aCCKkiVsPaxYx6r8oYMLjd4vIRQ1QF0cA8dwnwEY7JFlX2+ZyjM0/nxmqrbabnQvvirxa10uOot\nulKT0veWwWoWVVWLi1RU1VonnpzgUochmdZSDW10DLk4zzlT9atvjJGlJ0VR0umdvkbOAvwGw4XS\nIoXS25lubRrlJnmomoAiEg3cp5BZAysH605YrRc77E7vUaun8eJ4qFmnqdei1baLtVTqOFhV33LB\n0Fo8q/STSxQ3W4y+66fD8m461L7rknhl6MxThexe77v1lZzgLy+yqsRYwZL74wdrhHbWxg3RAAZh\nICFm4rfMRNfGT2yEaGnFGkGRkCDhlehLz0Gv9ktZlbq6naZDWpxfnh0ZK2WdfCsyHh3j4GY6ap5C\n/ZB1K10Oj6EbC8LroDj4mO4c+yYvOwv3IqKN6SIi9fVW+1ebR4urH8XCqqJ6YqxOHt7Vwk4A+Pxw\nzfkj3cBgOp93AmyMpJm26LhJKxGRnWDXHfLywHQ4HsvqLkbCLq6PeoiPqZN9v9znpUwUq1llPy7I\nzK0iCbSQg4wi4Yd+OsjpmV/r0Cfx4hKSRG/o25dv5GT88gz6688x9YtpimS1W95+Z8WQZDcMirmh\nFSJRHcJQwmaOdWPVykxGCYE3TqWHx5JVZCQhnpA59fesjildHBbHqJysMtY20nq/ln4hI2gc+3HT\nZRhQs3NyoSuqH2J1Zte66PrkpTpERVVAhaiML+T48GR1AvvZ3iJKhOaQfrh9K03I1rpJaSM2CjmX\n723CB8xdCQTU9pw1OrCSiQQcK5ecd0ncFukQ0/HdVFxxUY/WzFNJpS6GupETO4iyBMcuqdVLlwVc\nfeoKWCbv7JCmms+OxjEveoudNyhTtUlDVPF6ezy9Psvjg+3PLSurmoCpO77be5gBM+O8cR/Ixg6y\ngCi0jWnYJlMixojWBQIkKkQZstIe3eALSO+Y1mcTXLo3dTGIBmQTR71s9YghY7fYcz9aFjVZoyLC\nbDRalDSKH/T4dt1zuUIdi2unKiKmggDU9Or5MO2Wu/3Dsz97IpogsbAaabh3gkA2kNHYor95sphA\nSU7VpG2OK2yc7CDdE93QKksiIsJSCluMk+bVYdTbedpnPZflSUy5xoHdglc8yTyMo6TefG9XT/Mf\nQqLCgomjliTObbaH3b4cdckONZDSzGhHG/BKvEBcb97s335o/37Xtal7qqYnd3ONlERmztBN20EC\nqZFMoCqtVqSQqoxGDjLlzIRQCZAiJlqlUy66bhl7lXRRl2tz1Ho9LGSssub+6oVinOxscXho9s2z\n5ZB0EynW8Dcv0gcssbjHy/5W1pgqEqAmM+IMiBrON/7yq4v+wQf3n/wkgpWhMYatHhzBJamgEYjm\nNqMdl9Q+vzDsZqTdIDpFWDsviZh5VIhq7DSNldC0ttGsvzislzrC625Q9Fe2Tnz1/Pj2Ybc/fDW+\ns+4Gvn7RxfFxjmdWhxf4+NUfxoP1dVkfDSwUA2jSmqJGyTPZPL349Bn7s299aP/0RW/S6G7UW3es\nQpM2acpNxGqgLlIrP4ywpK1gDLlhFgYCotoYW6KVKEzjMmoZh5RLWKTLw2qZPLCrnapfy+CyG9+x\n6/PYbMrmXrUiq7G+/vTYMO2uNlhd/PyHH9aXdnvFSjdtmMLNjFUgWl598cWTTdcdvf/B6sf/C1IC\nPYjJTx4eo0o2URLqrQhpJDUBUxPCRDKBzlS2G9j0hi4LtGdSs280NPebcRvDRUGn17ujVZpgW/aa\nZK9HOm0vjvvXL63GyUbs6RDb5y8rHvCzfD35pOf5+/f3V7dPevFG1poJdGAkiHH34rPHl7va5fW3\nvnV2/aebZIqQ3mKS07vJo8szkUhJBEWlsSAlhRprowQFVdjKFxUEVQAa6xyYNYRPiwky6y4O7Fh1\nv1hmB67rUuGQBWp9tegvr05kczd8fPPqrjzZXS3vb68l191k9uhPvrOp751pdW+P9w03HjDqxee/\nerEV025x561vnbz5J38pKQEG2iGWd9cM69LMUwu5IV8qGCJJldAqmVCyNVttkjDrY0KFUBcJknhV\ncgkzjuuFXsthHFYLSOxLhtKnhdZ6sVzuL/puf/j8auubo323snuc9ofoQzA++tvf+ubknaHc3OiW\nckkzqI6//PdvJkROeXn/nXsnv/h/f1EtM4hkgNy6lyOG1C51OyVzxwEImFL7+6KNpUIxD9xFlJwF\nVC1eByxdXJ0JUsXxkcTyep+HgcJtdKZR6zJxetUttI7lab3e7RTDWe/bssT6dd4sD7v+e//V0ce3\nP8iH0TqGOE2aOksYMj198rM3apGWt28d3bp1/OLvfp6ytUNgav1bJwjtcytgb6DdebINiaQgqRKz\ncgWgtxJZSQUkGmmZKiph22/u7TzyopauO2xTvxTWfR0ImabVMHGc7vLy1YvSd1GqDB8syhfPh/Xe\njp4cDgW/9d/Y14/etX1RcyJcZ02GiMbTXz+9czQWOzq5dTT0Q//4//llJ5pIhJgxHz8cwC7JTICc\nw5ZIkzgEkwQkRzCkac1k/rYKIGpuLza9BdTik9+V4PF2FxHbWB0nxnZcKKNOKRfRXvni8Ru7e3gy\nMeTe2e6bJ/4t7m599upWf+uD/2y6+tbbXqoYy41yh6JifPnLn+pKT1ZYHPVJIh+d/99/oWI5EAaI\nT3bvRDw6I+SG33oTe9txSNq4qdKY+CJo4kgC0eA5GueGXsXsyWYhsjiVq/7JdjVk8fPDsVHrPvJ5\nXZ3Qn3yxWa9fvOGw0+WD5xdf5+8++Gr5yVZXh7/xuxfjh29xcioqqVQqRETs+if/qt53390/VVEi\ndWe7/9fPRbUpc0y55/rBMiLn1k3NQ9s2ppVWzafQxiq9uT43z54CUCFnIq4YVWDd7vH3OaGu08Xk\nyyw+lpWhwHicuDm/HK6f4VRejHa0G9ePps8meevo+v6vx2/99O4Pbz2X753oWN0k6HB2AtAkvvoH\n5wt5eImJhInY8mj33/1FG5FBEDCa3j6Rit6g0bBpypxHMLMc0gwBULz6DJq2HivCm8AlZlmuKtDv\n//1HPa6nI+5e3slFpst1Rq1Wl8m68eKiinXddZke+TM5Sk8v0mrx+N7T8//48fD9O48X37vrhcpG\nxFSGqjDq5//sk5O7l1cne276IWtaLn/5px9HsiSicIggsLjT1dp12vQqwMyRm8VBJKGmqqYRTSiB\nOYc0FggpgInUaDQo0dw9/onlRa6rtFprHrddNkClz1fPX4/77bRb2uuL6+XbF55OpnNmfP16fHJ8\n9+qH+OXp79wmPSBUrSXMAIaXMumj04vl65p4WZ151f+r/8vPIaoBM1YQIt3xMfc1pyY28SY2Upln\nBwDYivkQgUclIigioDb+SRP4SKMdqwjEMv/sada7y3p1r1S4DibZtOoEi3xn4Yvl5etL3LO9dV68\nw26DVO7s/vC9L9bfWdOjVnfCK5Oq+XSIvHj4XuYadr1IUm19dvgHf++VStd11sR/YqqL07Uf2Nus\nLpxx0Jkg38CeBGGADW8JKABqC2wMtqIXSklNJ0X6+V/dNaY3dXWlKEvDJFOXNkybJ2PG8Xo7ZZf+\nuizvvhasjliFZyfl/uU7H+US5uEhrHWWdFnKORYP3rzaPrx+vTgu0eHn/8snNUmkJFAiwtQ8dWe5\nRp/niKoz2tNocmw6igShhEgwAoJI87dtoulZoTKP2UXMLdfPD0dX6G9dW9lnkXEni+Spf/P4Wscl\nYyuah+v+VtrtVFhsHC66zXa5eNvqqOZTNXUVoHruVUQlre6/vLq4fnD4/AeZuyc//2Q0agbEBO6U\nFJFPTlhTbwRbWzE/KMKYvwZFZ0Z0E1X5zN8Qoc76NtWmcZcmtlDTacwLXXBKQM+xDGcr2b2aPGTa\nvNjvok9peSkf2rkwyv78+ird90lWVA+ph+KgaSCmtBxMVYK2PFnmiy+G/XOkshsFQeRMMWFFUMKG\nWwt6Tg2zhvA3lgO/GXsoUtPJNqkWbwSAgIIiYjOhEI0MPLNRIH15/jSLHqUNj1bwst/sXvb5aoxs\nsVo9wXh4b/eUSkeRW+9cLe6/fbYS7eRQPDFhXzs5WonOagJZLUxeLLuL+7VL21CqUBCqFapk2NGZ\n1bTIrgLe4KRsLXn7BRuooE186IQg5jcz1NgwuoYJN+CLaoYKmT7fnAR0qqsh4oDxs+CDqz10f8fq\nEkxPH71/eVGKV7z7e/v8fXYr06TTNFFMp4PaoleIBiURfZ8NfPqoVE310lqyVk0OD00e3dlaOfTO\nbBCJxlxpeXCm7kubUYTfqNBxIwBsQ2nOIlwFxUTgFNXkFypPznVfYipHi/35OFatu7P9NYF6nqZx\nKf3063Tv0Wkt/vA/W+sfPIpdck0YJ9L62NuQTRgUzcly16/WXfLDftot+nHn82wAjEII2J3czm4r\nq9NhdALzxEZuQOz2lCXX/0BBby+KQERNbzqq30h8bKjqlstXP+AL9MSkp6zlefZNZDl9PKWevJJh\nq4vDapzefIVa7v7tvPv9W7vFfiN5KlUMuVzqIiczo4JiCebro2Ghfnn68qMMGjq2XrUGJYn3t061\n9gslvSYx+Q/S3/ZQtUvfBI1AzLepvUEZM59Z1MxEwaj29nFQxLrn17FN3eB1aVrfHF68fnrtne11\nOaTA5rIe7Pb9w/P1uI/T//zo4lt3Jul6uKm7dd3hFTuzdFO8KiGpP151AzeH11dYLbOGqKgZg5Iy\nws46l0WnKkCt4ziVCCJ8pjHMQppk80kndZbHsvEBBWpmqd0qUmVdzo0Klc2vfr9Tl7SU0q0OGi8O\n2+lWGdPyQHZ1FD/c083lu/3V0Z88Ov/wLrraFQxJA4zxkAaVnCAQ2sxC6Ja9Za+7xTVSv+1qE6ZF\nAKDb6bGWbpGqImZJcoXYLEhv/wZC2zGnhzFmBQQEFEHK1jhiToqqlQvCzDTh17s7KaZ1ThrnmxeP\n62mNfRwvrxEmWNVxef5ykW4v0h9++OrWfe+GIfOAzqcou7HrTRWtpZKms0x5MSx6kymuNJYsZjfw\nvym9f3gCdkmsy6lBofRpGvdTrc2egiQpiSCVHig3Zg7a0gwjTOEGyIw3docmDu22vzIbV3012b1+\nfB3vXdU0Hs6edUevrGg3YVv7J/Xu2cMfXB7fPXSSxPcnQ706wKO3XhURsxYZAFXysB6Ww1gP52WR\nHA2GhgtEC1Z3uiJ9QlMYBwm6hjOm2RKBoiBTSOPBRosGmAmNTdgzS7KFEFGflBVA5OEXuQyrJFN9\n9fVOr6b9Wndf3x9ubYqWoxjq8/Xtuv3lb727T2+FG2hLsXS5mURMk0J1rrd1Jr7YsBpWw+S1bgfv\nD00MWScmZenfWofnTqIx3JxsslhRDwqSaGON36ifm7xtdvS4cRD4DYOraZ9cFRE1hv3jPbLXsU5X\n17e9pKEvvezftSsXX+3rWZc6xXa4mD6QiCkkFvcXm22xvjNTVUi4A+5OujPIYZWTupT9drGmWeO0\nBonA8nZfIjfisRCWcjZtGBwYXqYxwmMeAjdDBrY8CswVJRlzNmxsbGUg5a5fnAx3P3y/VHcvL7uT\nEd3JZeWDd+v5pEh5Jyf16HNd/0F3/e3eBgmI6Im+eb4nNCdTUzJqjaZ/igBhXZ4c8N2mWyz6ZRJh\nFEmmnk7X6uz0NzwMVTVVM5nHB+HTWGowga2YVxIebPVv61hUSaoghBRtAiPw/tn5Zy8/f3dFP+Cb\n7f3DfrV6VpadP52uPbHbTUMkcf/++49/+9Z+NVWDme9/8tnyZLUYFktVZcBFE1itEWLqftqfT8Ja\nr9NKpVJS3gdDyf72MqLLmL0W5rgqphFNtC5BBIokgopZQD63Vb/5RzDXmRpChgkE+tHLf/OL/fLu\n78WPXu++Pj3+Wm7d/gwsG8/o1Ff9Wuvw+s2j33n+8NG0WB58SCoXP356hDevDsPJ/TunFsICCzho\nYvDp4skXzw9OQi70BN6PglqaHHaxzhH9PFcjJFQEUkUU2hSmgiAdSefm/WbWQ2Ez8oj2B1MjApJK\nhErqXm8PSY/vbF6+fsm0qherW7vs9ThNaX0oNW7f455Xqz/y9XemfoW9rTKffpp+gNifP37y9Sf3\nPnhnoQFvKgWX2Jw/+frNVI6iUvYbXy+ufQiMVR1KW57oaL01qLox+G8+GhLgGkGp4UggFN6UIwz+\nRjIpYm1qJSRDRcEpYXGqr/t+aXX6+pPww/HwJd4v39z6ouuvsDr9ZpLxyfv6Tdn/jbfOf2Ac+vPp\n6Eg//czW2Nr9vn9xufvm/PKjNU08FIBPrz/5Zop872Qn9VAjqqzgKQzRuG5HQ/gqs+noyVnMqwyB\nUqBI9BBKJDdRSoR7mym0g66i6g0E9tbj0wHpTo+qHP/xZXxdBn9zNdzS7vjsR+m31qH7cXj0jMl2\nn9yKVw9/8PJby2m9iGl5Kn/+acaVnV//wcudP9zRt6+QUnRirLE//+KVLq5jezn6cgorzgElN/IC\nPI5u55BOZ0efNr6RuRUJgSghuYYok4LePKva8NMoZDRfLFcLl5lJAD0+Giy0462VvvjReFit7p9+\nvFqWQ7pcbjrE9f7tncOmx133Bzy+e+iW3T4v8U9/gVoO3NerSz9+/3RZUDarYCgMcXh2fvl8t7dO\nMouHBKUfpsRaKK4x3TmzyX6j7wWhQGAOp0LQg0JNzkRpJkSM6oEmgAl4c/9inXVQ4KQPvrP+QqrJ\n+M0XJ19+GR/eOdw5pL7Pmq6kz5DtF7/XfXKdhNfff2f37ZL6XH1R//GP6zSKR+SdXW+ufu80hfjY\nZyCZvfnlXz4/0EqlFNsLWErGskREhahFDCtg0c1MZBGRkHa9Z9V4BKNZLKWiAQRqKbXyBn8kiahK\nwGFBStTyzt+49eRgKtf19ZOvz+Pb7365enpxGyf72k/jermIevmr73745LzWk9+7/LDb9cmK1X/+\ns+k6FFG7W3+z+3efXv2kexgMrwHr7fJf/ctrWd966/Nvxj4OUEr1XJBtqhQguLjdwQbjDJ4E7MaS\nRQMMDY8GZAWTtijmDG/Rl2g+WBC1ZCpBkfDg6qg+F9Hy/GQnr+zte7+u61/07yqem2/r9OBkb/L6\nr+L4+Dx9b+EPDoN1ovyzX+MwLW5d5ZfHi/c/eu8ff3r76tQHdwfzUH/8r6+rPXz43dXluRoAWKmq\nQ1EAqphw606qKZOzzQAQFG/jkea34oymdkBqUp2YByIziKdCACZmMuvocPX5R5vbxzse4urTK763\n+oWcvra+366Keeh0effNVmU/Hoq99e7l749Tv+iDP/t6vRddvdVr33tef4f1yDeqpkm6nB5/PBjK\n/s3y7q1NrRpJ8hQL6WsJFyHQna0ZXSO9NapJS2xEtH8j6F4R0CFpSDN2iWaoheZEBklZRSRJaGi/\nvXy9efr8d753+OrPP67X+nDxPG4d0uliVY7Es0m5OGZfKTgI7l/f98tl10f9+NdncnTxeeS3vzL8\n3oeqJ+9eLkuuTrXc+xN8/xuN6e2BySrIscv72iG6yWGSCnGyCOTZP65VHLNDA7w5dzi8BNJitUz8\na/3iLDWe0UZTUxNTcH94nk4eHN97eHf7p19dV7n/zgUEq5QMOznuA9Ddl+wqGLX78L3dd6fT9VLH\nX/56SIuzM7mM0+9v3v695Ps3t6EdE9yl7w+XZfHRC3nrh3VfpaOUnAR7y8yoASSwP5GwrLMeD6q1\nAVitp40gIxzDYjkMw8wTaEyBuWAHRMxEVSFJmacfjae/ff/ktE//4h/+4uGD8dHdePw03YvusB60\nu/d6v6nVdhbJnOy+e/jeyWXf2/jjXx91YkP+9vlYVmen56NGfxRjmQQOQHj8oB5/186unn913uc9\nmY0+5fUrFp0EQH93Bc8WIQwJCDrU9nQxPCQqhMyro6FPKSWloZK1yZUbxE21hlmbqkjnh0c//M64\n3d766b+4jP33p1ff9M9jsbxMm0Xcjf2jFxfIDDL76PhoOT7YLPqu/tXPjvpODHrU8XJ/cXg96qPV\nMV54p0CUArut12W3fX355WFY1H1U6yMmOYrkbhDF+v7gyBZtYkhy8ibeDpdwMii2PFoNnQokkVKD\njviNoVLT/kmbxlHicPxHD9bj4Zs3+Wo/8dXz35pubR++qpPkVM+P6t3rnyZoUAB6ffDDq/cmzR3/\n6t8dDSknsfCFIg7Z8oKbcJOO4SU8UkeUap6WeeC1RLiQLkdaKumrqOk41dTNjbkS4b+ZVEUgQtif\nHC06E2kAnbt70IvHDU5EoWizA2Rwyr+38IsvPysn5bqsx8cn357s7PVhk7YfbJ5D7dXJbgdQWSu7\nH6bhERfL/sf/sl+opU6ZQF9iqSl1KSI2XYpABUtanIilute7k9QxT6LujEgyUpR2qP0CTBpkm5Vz\nBnHDmxorZH37uE/N209SeGWt4aVMjoZk4wY3mWv4y6G+vHo83Yo3vjgrd/Rn48mznR3Ong/vTq+u\n02IsQjKMpm+/9+q9Pi3Wn/0TW5uRopoFrALLKkFeo7OA0GtIGg5VxyrHB47ZTNUQFZ0ipHIq6bij\nZhB6Y24jbGRqEvSiJ3dPBuVM9UvVnR4MlxYdGlyNG2sRQNReyoD3Vq9+NeaTk3xvvDj2c++6h93T\n9Vntv/XpQUzc3ZSL3y7DQxv6l39/87BLUIZZgvRRYRahUdzMKYzCEENfqmZE5zWBovRp4kIPU4YG\nF6fWZJuz69o8iSPhZHU9fXA0GJr4TyQVkwqp/hvLxeYt0LAtVukXi4FHQ7e7flIWPEl9PbynPzsI\nNy9vH9eX7558NX3/6uXWAjXqhw+efrTquul//PJ+JwJPACOLmEmQGVNtg5bwUlwSkSQgiZKSwFTh\nxbUfa8dEOTq1yKmZkalL+yrNwJYROL53PFibjohCUgjnQw1rhdYMNkRIxLA+W6jj8s+/XPDSkPua\nxnuHj6+HGvtvpvXtxSeP+beOv7pokv3uO9Nwrx+6//Vf3+qTNWTcKCbqpvOPSM3RtRbSpDAZSNUE\nipm57B2ngQqpOFkC3exFGE37SYBsE5Dl3eOuiakacJ1qQaVGqyFlngmx+fjChu76cjEd3pwv+wv1\n5d3ro7N0/uKVRaL4a7w4ev7yLWTJAQrrB/cu3llb+vhPkftkCCIsyCYlhsA9VNisxzwkSbioeUBR\nS2hdOad6OIkqmZGOO6SsTUMFidkiEQQQ3t867mXmb4iqaQonwwnOxnWNwtF8EVTGYjnO38SJ7I9e\nrd7ZXTzqn3+6yPtIhB4usVMg+NYzOq3kb0/5vYVe/A8Xq2xCA6takWRKMS9keAOYyMTqsDafVAF8\nAuEWPvp0lGCLAxbHitw11VJzwG3aOPEI5/HpYK2ZVZoaRD1qsFaPkGgGDO7RNHCgOxm7i+MPz3Yn\nq9sfbL6oF7UuNozOVFynaatp8Ne3vp28TvWtW9d3l7n7579etElrFqLSx3FyVgD08FKqh5PiB7cb\n+jgYY1R3kJPlfs1QseNjcLGYDbFnS9iGkbpzebZMM+SmXU6qojUajdQjmjmRewNTGTGs3AP5vXd3\nj4/T9Ojl01V+VoaTfVokcfcp3d3WznX3zXf+5gBZfLS3d4bV5/+09kr31lurwj3cI1rvwDo1J8qy\nc9Ua0aSOZRpFsnjsJ8Gt5E69NZCpmQoSjfwXLaRGpLOjGXuEppSarrOVlQxHmyeS0pzuoi6Pqt1+\n79H4z3908u3x7OnnJ7f7+8fD2aJ2xlKc5F07vnXnvU9/8dv3w8/u7+/cG7b/4KUOQq8VgDcHoupR\nahMbl+oejChOlw4EaxCY9kWzqaLsvK4thdtZRrYmBZ2l6cGmuCRXp0NrmSTllEwYVIiaoFWSLSrM\nOccZZXn0ztnVm8+v3nnrWX75cnl8dHx/L/5AJKKW0t97a82799/6aPXzr+oU31F92Pk//AtBZy0l\nNfPnYByKuzOaJ4YH6cE66hJtdICyn4LaQ6XuU17VKHV9bMgSsx33bDeLBk92p+vmd6Y5t+gooqJs\nPVZMjfqnwjnK8GJ8MH69KXn98PzpJi3uyQDZnhyGwyYg2vdu9fb97kU+e/WTFzh7tD19lH/0D12k\nbxBnI/G4e/ihVC9k1LmkC9bDZIM28npM+50IxcPLLklnmOx0SSSQXhu1dwagKaCsT7rW71lKAlJU\nNYkEIeLu0boXRZvdqkh9fDJ0CwfflIXvjpHHusnLaw0PGVy0e5K/h28+++KL8fF4/KHVd4fP/4eL\nHLY4SES4AozaRfHZJbiUqAfttRFWJySrUEawXE4KUxerG2J1fOHHp71oQhPTU1v2plDCPZ+umgIx\n5dS8/ZSiEmoGJyN8Bh9ETSlSubi1LLv9+fXJdjrdZ98cfzXaXvq4TVksfIzD0dG61+XTr30c/uZ3\nr1y/+NlrUDXZbIFLULxhHR6gR+wPmtKiMzhIw+gREeVwRSMNIeUSeXUC7291lmx2eg0PmHAm9WG5\n7pUQWEr4zXRKLQsYagpV/f+Bffvb76y3o/oGNd0rUbYLXmK5q8uru/1xHBaM4d75P381fPGrAw6/\n+ycfvHdrj2cbTSo2hJrOHo9RAwIGPUBUZLN+KYcS3JeqrBFRX20aWSQSr2GykuiPFTbzdluomge1\nDFsPSgBq1iKaCMCkRSDmltKQ5q63+Z0vpHszSWdWXRf5UF+tvvVaN47dg8/f/nobw+Jq9/r55en2\nel+tph+s0x2r//bfj6YVkNT0142+BkSkmaSgne5f24p+uN5vDnuQgrh61ty5a6R6XXofkNYLYScN\n0BIA0czNg2S/zkJCkiFk5p4JUmFQFJKGxZB1HiJCoHJZh35ArYdFqstaNu8uf3nrclpNjzKH67Ph\noIeXx3d+5+h/mvpa33vQL7T+i3+0VbOg9E1RaYA3ckvjR6raYorzn/d30xKX1+s3ENLpz7cGlQyI\n1c3YxUKP73aiM7kPITPSBgAhfW+AiJqBaJNzCaRGDVLRvOqzNfE6RIRTFVOoVs0yPt9cHr//tN65\nYtrWu1x9ZfVbJ59Mb76Vf8IT+slH11ZXX/6Lq6Uy1DJLSrgRbczaFBWIpMU0Ln83yWbqyrQrJjUC\n21cqVcQEIrbZHdvCjm4ZZ0NymSuTYDOftkU2QjSl1soqgQBS89kMZBs6VaG30n9m3FM7r8/vyLOL\n8Q8vn613PXy8ePdwizF9lC+/6FbPnwxvHa1OD9+8+k75ey/ySX8xUhP2qtLot67aSGAigKE/LqU3\nTen+naNSQ+BRn10DKtZpEWDaaKTFcAwmnTtWSggJYQDB1ETT1rC3xg0gNIEu5hbo+xaUwbkAEkFE\n1JMn52ud/AfyU8vlQUb31Qf68GySi/fOvjnbvHp7fXL3g6uPcfbof3486CTaTdRmwE0EQkJaLwEK\nTEW6MfqTKu+vSqEHg9vnIEWtK6aLmC4ZZXkyQIw+j2njpv2OCOTOhDMzown2JRoAx8SAh2adB26q\nzYscEOHqdNGLL945/dl+Wby62Xb04w+/GV6V3X39+q3v59N7+c/3b//g43/Z51AmmhbLIqbSHuS2\nU4ZQgCGKXtaP4kJbWU/Gi41FSA/bq4jxkonHdyxEG1NZwHmc2TzeuiSY3bE4g3YKQaqOSRglhTY3\nVja5RUDCyfXxq3GF9M7hF5eIIldjn+3FEe5d3fEXY43hd99Op+ufnOP4zd+7TlTkLmotS8NNWCSa\n9yLgKlEFYnl17KOqIYLU/ddVCLFgQGuNi5L1+EjDGvMNpMzBt+X2lBRipjcSknnErQG4u6vedIez\npZnAg7649boU7d6SH10fLc1tGg/TxWHhR/l+/Wp7OPvesSzPNl/i9OTvP7WOXqHOMobUmL2+AqBE\nLR7uAbpo7vq+T0htqMvnF4TAmGaos4yVy0WbbszkcBHMO0MikJJAzf4D2N7Kx8RmIC4pamnsskZ2\nEYYi37s8LMbhtHtSNGtU6aZ6wIs75G+t/81m8dsL28vKf3SVv/PnH/drVxF0ubCETI1C2CAZJ4op\nvdHXzZKoqaqUChy+qaoh5p1rcajUaqZdmEHABksLtXk8s3oyFTFr00FSdPbhVx9LU0gzbgibpu2u\nhDysbzBBXu1f51SKLRe9yPXx7mPK/S+fy7d/+PsfPnq4+smz/O3Lfxfre3QwrA/WylJmuE+aeSHq\n6HVy91Dl9vn1EZ+d+3ioePGa2onQFOy6HL6bUs4zx6oV78obnhxUoAKzZvM+sxsEECZ3ACKOpi29\nGU5LUHi2+CqSjK/vWzL1CcPpYXEY/f2vfvJwOH5wdH9855bIX/7SH5393Z0gVTIliggKpmkglC6U\npISUMbIlEjGhet+v9jvji77ff1FMC0xS7ZPuRWUae1nkuU+/kbXNHQkYyUTNbr4HblZDSNpD6ox4\nikJtXp1ARr398Pn+1t3XY9aPLWk3bvPprubY5vfq7gm/+5FGXugvfiqr7/zrZx02V4Anlg5k0Vrd\ns1Y4jZKol19iseiFGMNpi8Wb55tjwbG9eAloEUnEpBy7YV+3x8OiaV0bGZNsIF1zEtOk1k5I21mB\nmcyUAioS1NQ16GX2BQvi6Nu7V/39YcfT/evlEZEVK3uVby0vjz76+KXfqbx1Olz8dNf/4OJHOdU4\ndFXFS29edECpCG9eOp4l5LTst6nLuXoRYfAoL/q03H9yEGOodONSwLq067h+e9kpVJrCQkmwuTE3\nNytLZtoOj9xEJ0EkKsI1Us4z56GtbwAW741f8Gx4degd2a/6nIY4Hq+WeZ3rp123ePdk92Z7+ovX\n/XdO/x9TR6AKLEmtvdMn8VqF1Ggrc+S4n6oHcqd9n21Y65vn2iE+f6qiAU1eM0WSinJvKTVjACoh\nGjMfsUGgmlJjKTa795ZeKJxrOgQQPgtHAajq29tf1PXtzWXkeOuzSyyGw2pcPX95di/vpue37t06\nNo4X+58u3n/nHz/rBMSYtOa22wxWuawz+qrqKj2jSRnT0dFSuSQ15fr6V0Uy3BM3uT/YQSSJuKW2\n+KSJqWltR1PTfFNTSvPpnxlYDRNN6gpITSIRc6lNQTqZnv1sv7yzfFwu/OFpAU6ls9tI692b/OAS\n7usFx9t3/ln61t1/9RfwrKkGc4lEJqmJU65FRKgQFjWjZlYX47Qt1zZBt6nD9MWb5vIJCxiT9aEJ\njm4eBc3Rkwr/zTogSUn1N/7jIo0lACR3mVdR1LnZEoj6fvvsoEdvnx+63ebs+pBt9WrQ+4frD785\nfsK3j/synW7LsOe33/3pn8Ng0kkiewigCRYs+148JQWTBMNETSQcvp2SFUffYXrxJAzmXSjVooZa\nNY2AEdqK2rY2gGhWWQyGmFlLdwoQDVUBJAldLSiZ1b3mG0bB4RCI2/Eqb67z0YvAQjbbB3fs5cP3\n0uUz/d7kb5bMA9578OwfbUUECGpFFmoVHbts02GriQRTwNvQm6bhtcDA/qSf+OaTHaoWSaiVKSWR\nEDSWHnlTfAhSOHR2b/XazNdaxmNE+HztU4VGuAwLGL3GPDmlmoT224oX0+nhSn3YYru7XU7Pv3U4\nSv22ny5WdViVk83/+MoYRokI6tRFiFidaJXTvk8UYTQnb0KF0sNDrevzVOPzcxfRJAYPYW5HIrEq\nGnX5ZqQpcz0viMYumWsoRrSmECJI2jamLBeRNIIMoQZUVBWHjJ2su6cXWtKoA2R7+kYWq9uvvn4v\nv3mgtrh//d9/LgiKhog6fEpkssMYXT1AudCYt8C5wgiR3lIKAu72+ovMMIFGW/chgQSQU3TVE2al\nEdEMk9lgstl1hpBGtJkXBhGp0Ra7fgZZm14/grRULxaL0U8X17Vqvx/73auL3fr6rB/3fnHrbXl1\n94B/9hdtLq9gEOoTM4Vwqxi1rhWCZO1qKixppEW2GCNo8eXVaVEVzVLU0RkUlphJsaDMFPmY46zf\nRCezxpNjtPq9DRXbdkqBqqiwrXyJmVgTlOm6urtADRYFw2ZzsS2Hxeoudvn06OLq61/9SwIekI4h\nSld6jQIjK8KuXl5uDuOhCINQs9Qtl11KkK5LSPvHJlSkzCgDJKuZOqYeRDSGIG/KEgUQwaCKDJ0C\nwYjm1tGgTNIToBUqMTXjMNxQigQU1HI17c7A4VCZvA9//tWDD892x7br0mq1++Zx75PBxaTtxMvO\nmhEmZCrI1+XYeyZLafY0yuYSIZnF7cVFF2YB2hQRpkmEnbpJlRvPTr9RrDdLTAQpi6xSI9p6OsaM\nuUZNIhYED11emNxoMmS2GLDz1xpdit5Hj8NhwX578QKyWL7envhyevGNjRCE0JVqArEKN0oYlKWz\nba1HzEZCqkjUcRBVCqLi8Ma1dBJgOsikmgJ0E9EkApWo/6HXAOntu7inwaRpWVsnHqCHT9evkkZh\nyLRbDVmgMwUYCrCG4KLKyb5KWiJCLuL4ztVifzng7N51uXvns6OzJ5SAimUqBBqqEtVUPQfZr6/L\nRXXf5lXfJU+gxCDwkFrrYRKKwABTq6bGRJNUTZTVa7EkUHgbMLdJOUF2q7YlSjhXhOHj5Ysvv36a\niigYPnrsVwttaDwYUb3c+t0X79U3q822c5sK9VBO7jzP593JxX5djtPddf/69YVBjLk7lC5LqFGi\nQkwSrfRLjIzrXR6Ww2KVB++kovepwHd1v1O3edcKyY5iojCL3FvU3db6pOnGyxpCDac7V0uFB4Qz\nsl92V09+9asXUyQVcbqzTnW9nknxSvdpeue/2P3q/c/K5uiQXTfOyu3th5vL8uad7vF7E2L/5sGH\n27+8ciW6PDKCqCpURiSpgjEfrbr+OOVFn43ZLKWsEil32C9U16sizUROiwRcYBIR1Xr1/fUUU6dm\nzCaaWwCnu+No0HBv3FnGeLj6+uNfvzzE0KWspFNjgvo8lSfIGu/80fm/un5QbbzVR180F1Vc9bde\np1e/Wp++9esvjhevfv7Dv3Hnf32ZTB3ui7YqMEB6BbWWfbdYn6z63jSbw0QTEJOL6fEw6Adf/+rL\nc1di9JpEGsVVUtSOvj9oZTFRMbXUKYgQiaAdd/B6M+icthePf/TxNWw4sZQNRPU60mo4eePG/OCD\nL/+3bf/kspbUbXPtzqrL9vTq+J0f7/fv3tlv9clwq/z07Q/+9v96gawWKSfRCCMgiOSguuqwXHRS\nlSoCE2/C4NQn2LB++w8e/+Uv34zwICR1YFW6ptrBS0pWWBuYlVJSFUGQtEHpzeOfXqfD+Wdf7JCX\nR4tId8oUqkVin7eHRZAMjdA7Jz/+3638zmcjxyKTVj2ZLuMLeevZybfvnN4++tHm5PF9Tnevdt3f\n+l9GZVSGmrB2uSTQs9IEYTnN5DxRiQgpoIhZUpCy/O77T37608dFUAxGhIdE2CA1loZSo7I6IMim\nMJHw2jXyhwgB3++e/fKnr3Q9rJYS6aPN5fXuoN1U7OL0aDamGYby2S989cHwGFpeHfW+B09K7H58\nZ/Gk3D86/uLVka93tvzg/U9/8bs/+HcleYWpa8ATvGnkqKVQ1SOBLgmhMSdfVWmWZug+eO8P//2f\nPT0w12lyAbyW0oVbgvSHqF6kumPU5FDVMom7xkwyicPX/+aTbV6mYRB06ez4+PXmcsp7j3FsdTyt\n370+19MPPvifL0E977/9q30eT7f9i5d/9Qf3v3r2Rj+5lQOD3x6vr7c/+vDheUu1rVBL7ppcu1I4\neRQkl1kwAhfQNCVTE6ckAvruW7/9T/780hFkoMIFyTmkKtpFiuRjserTFExLmc5jF6SoBLxsn/3F\nT3mSTXqDReoj58u0Z0weQWGIMjZjyOLu2fVjEUHSHnVaLPT4Rffii+O3N6uLfXpbduOD6+6be9+8\nfPd3/ndMHiKmEsEuRjrNNIxA81eXCDORQJhYSiYRCJUwReTvPnj/T1+U66OoQsA5sFoWSnJl8lRr\nnaao3q+kxAiBiFFievXNv/8YZ2tQcwoghebcdZejX9cm/gaVdJdv68uvJmGN9XtHHwy37n9d33j/\nrfqzd9+6+NymF28d6TLixZMPXvz8j995QslJTUAlUg0k8dVBISEKaNA0KQCfd6yCDKGLiND19L++\n+/95UkevAolJFo6sRpq0PUSWLLvUxQDR9VJnw9vdr//NJ7i/pFuftIyWSp8X2exymw6uKmYENOBn\n3/t3z64WexHdfL4+/r5+vh3L0QdHf2W3v3wRclIvbq+nbLf+Sfcn//Txd5+WmHlTpARAMx2GMQhB\nsvCu/433nWaDFxMVhM3aye4/Xvzdz3fTpDlYkrW1yAJRVoWoaHIbs0GHh8dJFKRPb379Ge7dYmE3\nmENTCqcNAm5ym2/pDBDePd7i0rIyFvw0vb07uB6tp1+8+e7xX3A9be/tqhoW9+998kd/+Ozts8eN\nZQgTKEK097zA5TiJweIwZGtWAapmIgArVGEMihKk/R7+u714LCDshhDO3pGwgAlQs2om8sn9ZdvE\nUK5fPJ1unwhMFwMrUq/bbanarW+f9rNeWkSEUHkzfbS58tQPS13afuMre/OrX4394eJkxOD79eHS\nePX6d+988x19+iE0Zc1ZQYVlE9M0rBdRLNs49UsFTDVEZl1jePVa287dAEn93f+jbnaTjxJdHhtY\nFWwwm5qYWKfQ/uR2rwAQ4/mzzfHtnikNq4UKTfWz17tJF6frtc1fQlRVRa4fP//pKDiMB25TfXmd\nH2at9XT9zWcPji6rXGA5Idxu333y2btPP0zl4MGYSnHCx8OBtlj27s6prhZtw703W/dg1CBqePWY\niRZk/pv/uUy1HPaHwXxWrmHeJKoqYrnrLJ2cJoUIY//yGW6tiEjL5DWsz/rVi+tJuy6vVvDmKt2e\nz475GUDn0O8ncx4P5aSrfhfXm1sLZLxQO7+I9a13tn/18OTWyeTFG37DKhKMoesWiLLDuv9rw2U0\n0k5UJ0Ghx40ocvFf/q7GdPBYwxGOZvgjQm1uaGJqeTUkFTDq9smbYS3o8mLQw8QuQa9evp5MU+4X\nKNTG0gKBLYd7DghOxlI12SIfSopX/dlhvL6fjk8Or3A+1V13tvv63551pxEeIhER2dDpFINojsm1\nN7ZFXqLajOGaypwRcwtIMkjc/S9PiniJZY0bG9L23DXAncF8fJabM1A9f1oWVizlHrudp2Si9fzZ\nSNOuT7GvnJV8qnb+5vYVEmzwS3J0tUQudfPl3ePY7Y/99vu43PZd2r/0/vm9o7MkCssQk6BXRp9L\nNEcfcao4G0qu0mhu4cRcwM/s5OBv/cfY1j0Xsx8A5qawwQzuXN5992FHSITvXmz7npLzoPtdRTIP\nRT1UiqTs1WeksV29N1NRpdomlBGyxoEJaXt+tkqXe9JW4353ela3T07ldBg9kncZgPhU0clgdd91\nBp2LvGYGJdpoem3Ticus9RABydWfnExl0mXD0We/dDaCkYee3b17a50VAKfXL6TP0plpKa5JI0Sh\nZVdgyjHa9vbZwjQPg/bUzl0oWYd7Xq+7tMi7x18ftOg4FgS9lLvr/Y77V0h9ajsJlV7ZrXVMvSoj\nqSjUQOfs9NHISq3OnlkAbajz7R+WWnNHAUwhzaYX7XvL6cN7x0NOIkLfPL009dSlJKUgZaltf3mp\nhVGN3kR9ApjQ3rq60Nqn0d1lqmernUyHtMji59OI0RcDdxpvUL8jxT8pGZ34QUBxS7QkvjxKsASl\nWCMHxbyVua2D9gg45WZPPcO5+uMVo+9KxrzXPQgPegTD7j08G+Z5yPT65SRksmwxhSQjDBplN3nA\nsgrrjewtCExPJvTdodTQsu++7dX82gxTzdiNu82Ay11Ji/z40R/8YPHq79wl3Ktk4zT2XVqtj1ei\n0qlAk7BtSuZMw229dNtN0SAbJcCQ33o4+tLcbu5ZAAa6QzjcO150SQFh2by4VDTH7OJhWaAcVeiH\nogLtrC3QbSpXie15dIv9BIhMy9/S5288yhCl7mTwsn15OIqjoa665UX6rsjofd2Papos3HN/dtIJ\nqvZiEW1bb4M8m4+J3CxxFrZE4u1WnX1PuGyfoRmpz8p6r1yuOxNLAsZ0/nqytg99PKDt7oopURGE\ninRp1ivPNco2r9OmiILsPrS/+iD6yZZjVSI7re4W38VQ8fHR2fLBX/7j3ck6XJ0eFoFhue4k4hCJ\nCjBsFqeAwhCEtk0mGS4JEmJsBLP8w/9pXDb6TKNVt8kUCV1mhIcKom7PD1lgmpLXiJQkvJZJhU2M\nLX3fWhb3aSolFCcYiwhUj966/FmNISmkGe6prXcffzq8tSjl8qvX/vrRH//gYQcQq0QNWu6TQHQS\nS2rKtgcA7pxjUAM/w73NpWYeKQMf3mMnAm1pBiCkba7rekNtk/vp/JxJRNubRBHh7p4ITg6oZaH7\njFSKMNaDbqEheerHN4qtqHYXltkJ8kng9Xjn0Vj86JufPTzulh/s39uMU9kewomcFx1g7rnLOp8N\nkQiBwhhJ2JaCOhJdCMYsjIpbHz0eANJqzI5oTRSq1mtbhUnfvt6gmX95GV1U3VG1S0KOI6G5GaOQ\nqhCadodLFxOI81IWZZp2zHZYplJjwNEzLFK92p6dL+98/fVbn50Wfbe+J5tJDdEN/SJB1afOrBMR\nEZc5WYBAKD2RSiCqEDGbLRLRf/fPlgK6CjA7lwEiBrFAaw2n68vJKeKCcig+JHgFzJIC+6mEiKaG\nNrcJIxhl1LSnEbJwz7sRskXSbawOUrO/43qdNhHv1b8aunp99HJzUdZ3d3uPbug7FZEyLvKgaPNZ\nY20yIZFmyqBgaBUilNFWTSHShycqYm1J7ewqHiJKORxyhApj2uwYbaYz7aTvEaGhSKoaXmoErUu8\n+QcSebHI8HDhcLsv9Cl6QsYOlv1kmx5e3+YrpO2XFx++F8MuD4dwvv5kpGvuegPFr6c+J0gSoc6L\nQ9Ecu9re6TlHNsVtMADGWx/InCWJZsYP0CzJ+cuLTRWhl+1Uw2uIlMtdWpqPE5EUKUWMh4BIWhxu\nCIQIRyzT8fUukI6OdxeT9IxRlj6ulu6rvJVHr57/DYSky1/88PtP16++DEcwJnf0C+uM8HqNIZkk\nBWb1UNs2K5yNt0WEVc1buKciJI5/51IP3s8cjIA4xFRMDt9shtOVREyHUr2MkTnW1KcShWHUgwpQ\n9+6BbpUwTxwhgppslSnL28OrlwVmCvT9YLuljffizWZ69PxHn1zQd+cv1z/8j1TLpHUkR0hKyz4r\nYn/R9YvUXKGS3pwRKFtyj1nkHD4zXkEB0ndWvHr95mK7n0hVFdOUUt9l3T55TVP4fh+opYqMeyxz\n1CkoEN+mtqi6MqCNeN5U4zaltqf53COMnSvyYhNJD7tOd2U40u2LZ/Kdjz4rL+x3v/o6Rx0YqMza\nW28U8HJzpx8UwpBQJaTtvpvnzE02FBqh8/ADFAl79PXli2nQoRu6bNbMiXNWp1C6JDFdbd1rQfKd\nDz0LQ1Kq3G0TA2CVjjdsyNn2lxeiiWAAZO5ZptVwvdNcorvYH6ve3Y37688Hs4e7X3xdh4JJokbq\num7RqQDlNftVL82suy3fFhHXmLnspAroUAuR0Jk1juPVz157J/3QdV2vnahaTpnT5Hm1ENb91Tht\ntlhgt+8WXXFGGiLK9pCC4lcYModFlPAm6xFVu4L1xYZaRNLq0FX6IZmPan45HN+u8uFzvzv+ejqW\nqb7SOqYAScurxbpTBl4/O14tGiNEBN5AhkZADGqISbSFZsGY4ReKSrc8r9jHYS8Ly9Z1lpdD7kqM\nMpwO6uP51X6zHRc27XQ11Apa0orD3hPBMgpF8xpos3cKQJ2ymFiMoPYx9Q4N9GOxheP4yS1e3/3j\nH10dHcuw/eLVIZdIo1JSTv1ykRA+fTmdHHcQMISiodIYn22teLtabMOatnPsxlzxTo6uutcYhZZS\nzsenEC6WJQ9Zp6uXb84vp5zKgas1C+EpO30qnpK7kyHaLdNv3F1UdP9msMq0R8hRvmYiGaYoghxP\nx08edZuz/+jnQ7/47EnZDbnk6gFR6/tFIsy/en56sjZtHnCMtjdbCRChICXQ8uR8SsJIMIy3z57m\n3muUOhYGwo5v31k/vOtv1kvxw+XrV2+urS8TlkdpU0Uh4l4mSuppqiFqfbZme08IVPcVm12CmyzS\nTlCqdSMCK59SX1e7x4sjwXevX13ferPvOEoqQqr0w6JnaLz6Uu+dDNL43m3rA9TbwoGgzrwF+Iyu\nWCM9AMTq/efSMTxV4TiWgovz7W89PN3sFh3q/vz5600stNTVqg+yWjKg7kb3tF5JDJOo9tY1nV8b\nDZd+cdhrpMHGi6QxduJwAdyLLcpq96vxdLUnbuX1m1SKFpJi61U/ZFG9/nT/8N5Rs02djXCl0dhB\nIbypgcGIZnzAmElMhH7wl1BlqIkmOopPp+/eSRe+SJguv35yEUPPUfPA/RRE18lUisDS21pl5RWK\nYdnQJIEIPHp0k2g/7DyM5qkfO6jvoWm3XOZ6+Hz48FY3HYa7z8KVDs/ojvJiSLDrX14cPThLKqF6\no8mmQBzaRk+ItuaYpDdvZLpQQgl9eHzZmRpm82qyyjLhkBaYLj/72YvapxqyXtl2V0Wk64tHFy7J\nVLtVYkhaLEqDZ9vF0jELohRH4xrDl3RLBbmflkyL/un04NHZy8vn5oToBLVlzkMWbH7++Pa7DzLB\ntuOriRuphAUaxa8RMxRkVbXwG9IiKcf3Xg4QWGjyoclZdlM6Wtq4/eWfPSld1smW67S/LkS/6KQ4\nUkhWR79YdUpoymh+RyJU3WvZw6I6BJIh3FuUw+rMpJweX2vCbb16+bOrt169AMxIJFus0jpbvP7R\nk9O37g2YFUUN69Hmed6gDROgaa1AJ7y432jQg937OhUKTC13eXF8fKJbt7P17slP//zrybIeZHWU\ny6E4LA95Koypdiltbi16C7JWJLTeU4UBSOzDjJB8EAYxnOpBFw7xfOx1XcrS39l/9qTvJgoiJKV1\nfzT4m2+e1kePThY3jENCZwY7oNH2cnjjlxOBUJQQVwlv26cgby+mKokCEwvrBEdJe79+/MXnT6e8\n6EIWiwWnXYV0q4HFEaOlKW2PFbt+qobQFJz5taKyQ3HJpm7CvLgyTUcbpuHoXAu4v1c3Q3/+4Ysv\n7+StWgmNGBbdtLm4qCf3T1d9206GEFqTC8Nws/SAbQENIsxVAYlwCb1ZPHN69k2TW1Mlm6lnYWyf\n/OSrqzENHbE+XeKw39c0dJ1UyLSrnSJtNl225UTTZEAEm++rDs980lY80iWbadEIWz04lOtH2zy9\n+6vdUvWjl+dCCemA7Of1ADs9O1p2Q2pU3Vn2CWuMHQ2GosIag0jAkKAK4gbpCpCLd74qShMjojPT\nItevt8//6ivvkBS+OF7rdjOypmGRAfh+yphKqle26kuh0xbbG2qaii4u925RTQxJ8iQSpUxMx/3V\ny2r3nr04vfdqWpyffvhxVQmD41B0OFktV4tFTqqCUASaGZQyoAiFUECN5hEbymZeDwEiXESpoKS3\nrEgnTAh2OUni60/8ybPK1MEnWw7cX2/rpOv1kFBiLFKkjCmN14n7qSai68QbpQBEniZBL2NKkkWW\nF4GFF5bDYrk/zu893b+8s5sm/enJwy9SjuJIq9WwWHTdIueGwwQDRog32HA+FkqR2gA7BSNEEFVB\nMEJmdty9o3NzCVOFdkbF/mvfyEJSqdX6wQ67Ec7hqFMtrIVOuKWU6s4PkYa87G2ekAikXlRbHFwg\nDMXudh6nflfgG33w+fX06K3P3pydfP36rZPXBUXdswyrk8ViocmS3tDVRWEgRKWpCWdqqLGRiyWa\n/bNFYaNYqRBUntx50bdE0JtqFbCEdF2KKDL0tj3sqclWQ5ejTLvdGKxBDdbdZp+OT84enlqjBkNU\nXnyxLyEWqqieajmOOE5Qsef33kuX063V+eXtxZdP1Z+KuvUpLU+O1qu+0+ZVLZy9j4Mq4XGz7EAE\nAtVmQiZNDdm4EhFtzzI9+gf0EAElZ6UI1CQtVoOqdENfttdjdVsdL3ot034/VQaR1H2qhYtVzqu7\nA9oIKZzd/dXoPSnJYWoXK8lnUOvyLj5Iz55gbZev3/XXVycPjmMqko9P14tOYQ2lDKeIiYdTGpz8\n10cdImrajAE4qyc4Q4Ntz9vbfY0oXkWlbZOWPBwNKtKtOhZHUulOjgaZpnG3m1wgwy1FSBQsJMDV\n0DiQJGlHOaIgdQbr+mWtx0MUy/3kX5/eLtcn+bTuL9bjXu9/e1VdFsd9VghEdV4MJkJl02mhCRsF\nArowtO0HgaoKVFijQcMes7/7vZMaLjUQwZRM0PWrRWaNZGU75UW2/uhklTgddpv9oda8vH2cGMI4\nOANztmq3//L5Pu/KMoQmOSNfrs+mXe5Yp8uX37rQ49WbgZNvRco7f/jjw+kqq2lKkDbVFZ2VhxBh\nqDc9jMisKVTOBlkGCQXobYUFyRBQjh9dwkcdwtlnp1jKuY61ooxjZBPI8njZRYzXl1eFKun4xDQ0\nvO4PZEQpCnFRgMpDiLMLj958752cvLV39cCufr78oMo+DhcY9KrsPh3fPz7qskLoKrPejREiomxa\nYffmp0ifrXh1tqk2BajSrCeBZrlP5kf3Ogl6LbnLybqUF1l8mup+M0qXTHPKiaz7y51Pjv5o3ff/\nX6lU7Q3We5CWAAAAAElFTkSuQmCC\n", - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from PIL import Image\n", - "from torchvision.transforms import ToTensor, ToPILImage\n", - "to_tensor = ToTensor() # img -> tensor\n", - "to_pil = ToPILImage()\n", - "lena = Image.open('imgs/lena.png')\n", - "lena" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMYAAADGCAAAAACs8KCBAABTfElEQVR4nFX9Z7MlSZIliJ2jak7u\nvY8GS06Kdvc0memZHqwIRgRY4AsEfxkCrEAgK72zg9md5tXVRbOqMjMy6GOXuLupHnwwfzmNlJSq\nlKyoF+HuZqaqhxmfyMPSghBByAUL5hToIFgSoBg0CJalOiRHUBRKJQEAkCkNAJMMGQWKIgCll1kO\nJInwNImAwOTsxWUyaDEDACZApCc8aFG785qkLd9h60xXElY9YYIpYQxDIPPjsyIkKlMUWM0MoBBB\nIEgPCfIkkEaQgTQwQRjEJJkyAQoQAgTKJIhMAUTSVC0TJCSSAijJBFkiTRDhUJpgTIoUlICIuZoA\nG7b3o5FBrO+aUJASQDFH378xSARgApIkBEEwgwQEQXMRICSKBNRepwQBEg0CaQYIgKpEPX6MTArr\nXxKkEGEpWAIwLZlCggTMjJAbRMEBSCX3IjLsql/WFxEwgRIJgqTkQ5/ffmsQZQYSaF8bJBIhkFIS\nSEGggZJkJCnRAEIIIkWmJEAQYQaRkECABiBTdAKkGTlTNEIEUDIgQWj/b0N7pyYJgOjzbEBquMxZ\nEkhTIlNMIIlAWL+d3sWZwdpbFGBuJKCgyWmU04wkDJKAJCCE1NYNQFBJg+Cg0o0gUiDa4oGRgAFK\nE9t3DAqUChRORhohmpEAzYG2KgCBLDyKRNXZuAhQSjCQCLZFChvK/sbGYsgkEkLSgES2X1+VmUyk\nEhQgQAYSZuvKI5gspIlmKVh7RTQaEO3VtgXLtm8gSNyKSGeKbjBEEgo5BWYKMBOQBiUEr5MRGeUJ\nJ8EAUjDA2sawsin3r+c0Ga29MBoFERQA64qBnky0I4QwkilkEiTXF61IpdRWEASaFADMQaKteZCg\nJ0R3Sgmm2kKvopQJku30ISjRSYJGk7kd0oiIcTeFku0LA+1b03u/e89d57Bse5CpBJSCRGrd1eyR\nQLbtKkGwjLaWle0DE4DEVMKkoIMJUSGsu1dCCmDWBJCQGVIoIOkIWFsB7bQ0hSSkRCHZ1RMghV15\nBYBcoIRUioEc/P17GwqDJj6uRHJ9p4Ap1U5eIelIQ0Akgu6PR0tm++3NxCRIJamKVIZAEKAiIYFJ\ntpUgy0zAmIEQYMxs50VbKKCSAEQlAJZDJZW1fzItSToFQBkR8L7cTH1vNWXGtmQfT861OEWk6rrW\nFRAokqK19ZNCytwIKSPb2jFnO4ZpJiUgRbZ9zvb12P5yKWXFgBRNCBJqGzAFUrYepwAKjmZU6qKf\npLRibY0m2HW3b1EiU7BiUlvGRoHZikDQDGinC9A2NZAEAYlEO45FyludSDBJrTWwBCC5DNGqDlvZ\nWmsOKUSatX9GezdaS78JrchBkKycth2U0T2Zpw0qAQlpYF/e3BlVYEC1tjtbFWA7RQE6oCIF2xsn\nZCAMmbkeO/lY10jBTJQCMJFAVrXjK2nSulEECXStlRQUUpnGVPtRa9+TePw0JA2ed0tERl4+i2i9\nD41i3+/3ZVskUvFYErA+TPsvZbbjzRQpMAURyCRh64lnkMwA0bKVSfOUwLZ1XIlW601oB0/bd9ae\nwxwKwkrm+jqsHT8CuX43iiD7WvtiqLrAKaC28G3sbt+oSytOydxAN7L9JhKSkgSjIlq3IbYSDpFK\nW2suJCCjImsYzajMbMuCBjNJqfZ6CVjbr6ICJiWRkndEVrXmSJlQO1j4/fmNhFjsIQhlHa6nCQRg\nsJHv31tPi1bvZFQqtDYeas0RLdX6OxBImFLWKnPKBEmSSHohrT04jAkCTEk1a3up2R43gVxXYKq1\nlVICaE0C1VYQ2f59itZ+ESD4cjIiU1fbWQIirfO7WxuK2Jl5ClFgSbNWPpLMBKWULChv70UCEVBJ\n0IRWsEmIKUPrY9cdYwhaakkCDqMbPGEJU5qFPKn1/UNpJLC27khC7W9Q2bY4AdHKYTQhWK419Yj0\nwe/vrBCOCJkAlDQIrhSRMlq2mgDBkCIgowQTLdeDF49dMYSAt18FFUhVSASHklGzJrqubWesZ5HU\nXgTXpZXW1oDUzrLH/ZkuJJUEwf50OKMyeDaezmk+locbdjSQJsAXlOIJIkNrn9RWs6z9NiQstK7w\nMAKplGFdPeTaqGT7XlWGMbPvh7QsrjjcH6N0nYFJsL2J9iraNNJqhCllokix1T5jApS1DpN+HJ1U\n+nNGx667vUXnwdafZlpmEbTuYMjaTxdFrcMHsk1aYq6/BG0oaidnGha6YJnVyw7OOQdfDkU5D+m7\na33Qn6YymoUTeBwyubYErYNLU3v/wtpfAaDAaMc1S91fMpGxuTicjf3DrQYkYOlJVpmxBNeKAqxj\nZyuCUBs9BKYZIFvPXG//C0AZSsqUSCZHM9apROlChCOPi3XHsr2c9ze3Y28S20NAYGX7Pon0Vmv4\nr2pgW4Va/8NEn0+jlOlP+tIf3qN3QVBWEShUFtKVopRcBzjAAkwjKLbhNUF6hWBgwpLW9r2ktBQd\nGPxYLIMWc8gjXShgxFy6y+vny7J1EART1qavxQnCaOsrJNrEtG4iCrRsr1W+3I+WmTE8u3+4U+nW\nVg6o3icAA+AAUsnCtnHbiZ0htHFA1ZkZyHX7JRC1bexlnlB2Ow/3Wiyx6aAlM5eKpDOMyNN+uv5y\nvNuD5kyVJKGsFJDWJj3icaG39uex6eY6GrJb7osxaz7vX0axSBoyzc0ckYFCt0gqKryN+8G2ZGBd\ntO4NTLYZGJTkbRQms8o/uN8t81yBI7GoR9AYPckpe0YxhqWm0+6zF0vZWJraZmtVThKU3lohMmWU\nKYxqD0YyWx/k3WHrgHD17fGMtAomTVXOYLfIiBSRtfrjIUIj4IZcakiJJJVCKxGEPdZdVbvAQXe3\nJ5SiYgzDMVjSvaass+g8a02ldDo++dFwd0oa2kHeajXFx4kFDSEQzdpoCVqb9QwEet3D6Nv69VhM\n6SYzRNJTsagzN4CMJawVAgkQpUjC1gMLrdkgqWyzOwlEfaGHw2mWRBRKLCrMKptngBFdgVBMKebp\nXh8/nx4mFfyrURHIXH8cFVqnnXWgVDtpGrrC/rTv++34Tze7lNGZ5KK18yw0KiVEom0ZM0HrHx4G\nkbkOC1AbWA0NN0B2m1e04qV41FwQrOkJsKorhU7N6bAACVL749Mv+7tTDaH9+ZAg6PAVd2gfJWnf\nD84QW8uYZCl3cXb2219d94ak0hmCs1ZYX6aihOV8RIcAyYABSE8BlihAtupHh1eQWtpCEKz8qu/I\nAAQvTAO7JbW4maVZJsKrSM+kSNbD9ofXD5uxDYZoi15A4rG3lQUg2gps/fdflwBHef+LX52ftbHI\nENGDkscwHB+KRKGSJcsKyIFo9bBNTVRSMkBJtoeTTLlc6DB4hjxLCVaUKHUpNghLYSTAXuiRgrfj\nJg/9Z0/fbPutS0DCLCGVNvcDzIbQOPPxOZIrokUkz+s/fbPbQTSGARVlyQT6zfywLwI0zQ75ik+i\nHdpqCERamreTRGkSvb05wj/4bWFNdCLDonaSMKJSVsKQ7gmWbENGmsApdP7JT/X2dDlAWHt5yi1F\nC1ki2qhMwcg20PKxOy3b29fjSGWhUTXVBURZf9zfXxURmryblCYQLqS5kgSRxRchLQnIzCAoKALp\nuXk3DZUoVjPd5t5CosKlUgH0sKQtJhJZqFrBLX15sfvU768KFPQ0tb3tARkZ7U3BYkVV2lZnGMy2\n9d2yIWCWkmQdkJA63h83HxgUS6v4DSBocC3SJFgukGFtQhEQ3cwJSJ3fFwBmGV4sCjKkmhYxBEqR\nAGvP7N7V+bSMF2cdlv27V/bF3T4kQ0JJIteOVOa5LBWSWTsXk/THwXDMb+eNieYUyTTPBaF+E/P4\nUVdMOTdkWCuOBUmJhMhEqSlLUqAsqRVGcZZ5MVqBwq16Mmkg08BYIFqiLJ4sYTwdudkUzg3mnW6f\nffhuHpkWRtbS5mjJAC9LhnfIRkikt9ENEIbu7b5HNK5CKStRIfV+uF+eni8FMUXvbXxd63eC8NKm\ng6DorZKAucLMRhEnGqhEQTqYhazWTmKaeXgSyeJ5qNxtLGaIgi3Guf5Rf3PdCY+HOc3ysQ88+Vkr\nJGrTI60SYDfuX7MIdDAtaimqYF/NbvX8fJksp8WNCtJgBBI0yBuWhrZCIaC0KbedjyqcjDJGlqgl\na7S+e6nJkt7GoJlJHG+ns+fnmKfaaluX9XSrP+V9tN7OCJAhKgV4Yc4BGK1hx6aEIBv0KnrSOss0\niYMtAafjxMudlqXEYkUJT4hBtp0hAVkAjxV3bquNoFHJdJ5kRKZ7wGotJkylqzRRNC7whQVWXufZ\nBWMOloTLQumWOp7/4J9PmwZPpLfGWTBKnS9L31o2iAmSJtPgLw/Pe6/OhV5TRQuKuFB1c57QXJbs\nSoornNxmSeba/ckyV3AkaIKEoKXbnOYUjAEoyBLZu8mSKLPClMWNt4ftc801OSKV1dI6JoHYf/Ls\n7dOeslzRQYFEGqSlnS70hDwhSVk29+/6eQ5mkpnpQ40OoU4adqg2DWU20wptCAa4IJq8wgC5QNBj\nhdxYUyI86ITCDAike1aKWbMnPJZBDmLYv7enu5gRHTOVlhwqaiCtLt0Pru7GnQPpWAEjAKmgWAcx\n8f2xSdiQb0vncmcSNayrk0mdJh/7kjFjMLRjk44QIWQKmUihai2jihVvsTR6EbsqL5EkFN73lkYh\nC7xmRPTsENzcvNx8vFmm6mPWOgeHIedlmum9o96V//AX9WZaEesVnjDSgLLso002XPvUvnv30Jus\no6RI722RG6HKjS+I3NLM25Sdj42nNWx9hWxaH0KtuA+kjL5Gg8EbOZPhHhWWVDGrWCJk/u7+kw/n\nw+yeAkrp6ml267re2lh3e//03+0eAiu3hASdjWrNpbYmrNVz+nB6bU5vtYTwMqtLYZm7MwtFjqbi\nRFqCgK1VrtGTViETkm0GJyiDETlvuDdJngKkhDpf2DuSlmlwuNeH9I/6E2mzTG5V0XWzSgisNNSu\n3pxe/Oh4e2GwpLVVValGHlazFVCSAYN9l9siGKP2UTtbqhkjFr8cl8zoSsgyQThIQ0AKKUljK3pt\nHFcmiWQmhDp/XE+EohFqmcaoBjIW0YtB0vFBHzxbJgjhMJuXKJ3NBQsyQSeNpdTXT768mVfumSDS\n4EbRsSwhPRa+rn+33zgEi9IpQMzwRAZ2G8WylCKhGKU0SGtVUzuoLNUhRctsGCQbSWRzv93TJUsZ\ng2YsM4lSZzYaJjIOZ2fDAalgibKIPZCw7PajyaSUFifi/pPL044S7JEEzsYzY3HXHEp27uP8mi6S\niUygt2ApuUx5cYFa61gYaUaRxQAFIYrWeCekoGiwlxnh5sqk1enFciwQAwmWzhVh9G46uUOOpPaH\np+d5GrhM1uc898NYowoedUhyijB3Eqnj+OcPR7QWTSuk6iZ40f7d62n74tPuWLB8m70T5kgYDFnd\nMuZTf27TVIceAlWYRIj5r7YGADoNLNaEE6JbtIKy2Mcv1SuNWbQAbeyv8kFpRdXm+83VpkZOKhZ3\n6HqL8GptulZIdINqmqLef/TsNdJlCwDIEop0kPnG/uhPbFqe4Hz55buNWxirrKu0SBiWxNkzX5Za\nSivWlsiANwCU6yTZumOkFlGRFRlzNprzcOVvoSTQSwaFDF5U+lqtzPs6HS4/ipNjOtiIo29LXcQo\nzKoulzDbeM2mNWBM95+cZpBiI9YgJ4wxvSv/w0/rN3+4CR4udw/GlBFFU4LGUlRnXg91Tut9pf2b\n/ANYKa+1mCMrZVppY5BORAJRP76LTqRngMWNTMhKvVdfluAhXmyXIerkY72dL8ZY6G2E95gq3Tib\np8k8gcjT091NtG46gXbeKB5eX/2PH/7hd3F53hPzD1mpYpFWsB4pUm42WgL9OsLDJBoNjYhD035Y\nG5qyjXpubLiem5ZyfhgKAIQByUVuFtPxOGy4EMfN08E47Q99l7OVWumCWERMi409kphNCyKyCnGy\nPz4cSVqDj2Am1dubH/+n4Wcvzy48l7rMTy8PUi7spiWLR1Kx5PYS8xLFsiZIyoxI4v8PwlXDahla\nsdV2gAEs83k5HUNAlE7Jit4156RLn2ocbs9eKHAMp064HE6zOxgWpczV+y5nWIUthJzdUEzLwycv\nbhRwrFqWqNOb03/4y/3f33y0jbmq1hOeTCAJmTlmmhkDF/2yZCmtpYfMHis1c6UjyYZWi84kxPZt\nZJ5QvTwFk4D1p5PVMiprarycH5xTvbxQqYfFt6gd51oKVWqiHPfcDVqEUxRZ8c6dWWuiPsw/PO0b\nukJKyvtX43/64e/+wT/zpQKxZI3dAmXje2saDcuy22ZNFihTgqgCwFLWVTI9mzynISiyx4YNMKZE\nTNP5zGI7eT5oY44IJYvN2vG+Xl0cuqXCu0MZcFqGYZkMPZdj2SjkSxbUztOGeUl1BKG8O//RcTxv\nkDrj8Pr+x39x8cvX1y/qIi2u3gBUuXGZ5WQRY5780o6hjtkkQq30k/9aHJMrqUy0Qcoob3IIEvJt\nHpeSyoXnHUDU7Gm0Dd6dnnfzeH+0i7wvQ05drzCXYzn2G83BLAiWhUO8uTcv1gaw5fDJD+PUZAF6\neFX+z395/0/vXlzFokynm6ilGBM0R5W7L8e8GuY53LliWZYqK+u2EqcS5IxVXyYoV6UJLUmr43ia\nyKVkVwIzhmNv2WGO/iGebeJgpv6Ou3roitmeAwyH3CqMC0DCqzv2cVVCSEcm4xRnH303WFLTcXr2\nw4uvHuy5T62SiSB5eK9MOitpgJbjxUdZ01ySGhVjWoHyVqolwFpH4AAQMEcqBCgg5VS6fTUH3OtS\nMaJ3dqwqD/hwN1VN3fCGO2EglqlHlLKP8wGL+p5MBbs+X9ozVlBO0IGc7s+GCsR0uvyzH9z85ra/\n4JxJwWSUDDe9uSKkCCnnpbvkEiyeAolGtRQ+ao8IrIC31ZWvcQVoSoBN1zWf17uuXzrTycIHHYvb\ncuoxlV3MeRrO3+sqD6WLyCjDOB9yU5aln5foNQ+F5vvjVZ9qkzBAX1zzcHbT1UP38fnxVbWLLhcJ\naSYzgd183FGNUYMY6ZfXUQPIdsQ2mLk0lVKKlLmtbFOrRiXhkSaXrRqR+Xr2zHCor51yLkMcubV9\nuVxqxOA3OAfgjMSYyuVw1R+FanU4udcyHYt2PmWaw1lN5hacT2O9659e6N0J/ZZzaOVYlTIMD7Uf\nAMua1ild/RNNFY9aF2MmqUJYNpRTMokJp9VGByhoahCMJHLGxZFFHul0RhbmVAbecVfh+83wrpzl\nMupuYEkN+8prTAURcPUxatlvBuC+6y1KJQQPWPScdufzR3YzW1e86QLU0IYCavOaBYQx4QFkbvsa\nWRq8nBDTAJR0NpkiG/1tSksGIc20pDHVOg5o6TcPBDMLlxzMofQx7/BkoeqY78tm4dkBfZS+7E+1\n39aoXmBVBbuYcc4w3yWg5MRVrbmwHs697rPYkEtbI02f4RYq42tXiO0sdUW/Y6T5o0ig6SisIAE0\n2KusLAPY9F6rkGRFdyws5FOX7HrOxqSWzk2neCo75nncDf1SdFQdTVPguHHU7OvU0bJ2x+gsGEtP\nX9wWYOk807BMw7Ys7/riveei9RyW+QEOsuBVOJQRMJZl6nZlqVoldKsYQIQB/v152+QxJqjJQZva\nhZYULYj0jDqHVboInaoHHvI6stpO77c7eFn2Gnx0eztdb/OOXQwlrLjmLM6IwhCLqnmY5Iz70+7S\n3x/GoXRY5nQ2TQP7aarpwPbufpdB2xSjYT51I5ZwW0m2hhoQbo7gI3vZZFur5sVoQNJKos2yQBYT\ntmbFFjE59qW+t2sxcojXm+4EX+axRN3nYbPVwrLAo/jxWIjCeU9yWIVN3BQurHv/6Pxw328Gt6wJ\nJxroWeIY0cdiZ19hAJFTheYFZWc1WVaUVpCUSUTJlfDBqvcDZaJyrSiCGseQEIUIpc9dduU0bzTW\nZbhGLuXs7uFiTMvgWanHs3zY7eopS8nZupNKTJ0dVIbE0g+TaFF72JR3uNqdHspQmMrFYO0IBUq5\nXbqLw7vLYr8rIyaWFDO6wa/r3MTm4U0q1yb2Rpagaa3Yeq1V2KEUrIEmDUlCdEFm16WNi7pTmW/L\nk+UIH/f7M6/EgsIHnE15xnnxrqIM9TgUjv3hyN5RIhbvIkGExWkaX/jbh34ozDqL3wtJiXJ6KE/4\nXVcu3txup7rBUovJ6GeogjfZ0CrHk0AUj1Wv6graun5sxXCpNFaSAgNZN1GdJyQ5Fex04lVo6TYP\np3FTK5a+LvOZT1n65YBuUyaoaN4sd10PT4aJmumJMmu/DFd6q947qxHtIdbDBsVuhuvpu7OruP75\n4NSi9BJpMQ61ytlIC6IJWExiWbdU4lHOK7Jphq0pkmCWKikycjPBcRpj3vRdjdN4jjlLua/nw1Ri\nGeoybOJBFzqxyzwxB6PVu8QuKjX36akafT/1ccTFdj953zGnYJqteE4DO+/q9XLzgd+8sK/POymc\nGUH5JucwJpq6tak3TSKKTKIlYQ3DXHsTXyRLUyNmUEWXtM1CK+wXY+LU7R669OEmeg+X81g3vT3Y\nZvF+7scHmysGW8J2uWTKfSkqufQhPGhzfbhbhk0uEcne21pqCjTr5veXWj46vOw++sPxeW/JGKwK\n2vZTwq31hK6mMgBMKEnLMLOqaIx3U0IkkDSlA1FcTCYUBKkeS2CYDriknfpya2U31z7spH5bb/Os\nZwQWjeSE4xv/d/fItD7llr2VKcvB8gVvj7uemBPm3sizVTVAljfbbfWv32+/PP9fRm/VyxfO/TaX\nNGsUhSmtsTmSWKA0bwQoACQTrswAQLFkMxjA+D1usnRdUbdwVxeWzW0MA+S0h7GLabGd4vS2MMKv\nN7rg51/fsO94c+q1G/Jw558+7Jfd8+V9uSRiDvO2NKhHma53s51Pr/feP/9y/80uM5XMKWTnZY5G\nmjeUtJHeApMFj+Kv9agFGzGQ2Rh7lZStlImiz35hpkva+ATb3M+bXidDPWygzY1t3G7fnX2xfBT1\n+Op8W598dFyGvHrm+PaQV/sP//7JF++fjTeHceuxwJxIeyTrJBL9aPeblw+98eLjy/+vzp2gnFaj\ne1YXwJreYAUJ2oEFFaRBHmrQSHM5paQkkwmQclEyhKwOvgyhWsc+7okcD6extwV2zIHIBxuzTPHl\n0/v7bl7sQGjyF5++eP3bMxtwd5jsaffb/+MXh2/yqp9D6VxlwBLTmEDZjNP7m6O20tnF2Tf/eO0U\nMljqwvMyV3jDoyyTpvZMgKGQAsKaUQBCG8exqmda466kIPNyug6D14jArXLgcb8dOHE4cIDxwC3r\n6faif3831no+XD/fp9m7O+jj57/bTe6Jb28/un773eXGRK5iKmt0crpss/HDb+/TrERcXl1u/5f7\nbQ8KvWlZzl7Eku5NZEVSSHDFdVDkCiJNSmtAZ6sbKZMoz4p1ykp0bz52GhJLuUWv8NPYs5rvYysF\nbYTs3bi9Pzy5OnxW3uV0c853djx7cZq2171kx82/vfzu7YebOmU2KVhj5gDJec53355gPSy6qxfn\nw9+9fe4mUFHmnLY2pfn37aCtFbDJjAqTFugS3mo6GnvcoJ2mn4cFmZT6h//TMEfnmJ/2ZRqW41BK\nGKbs0hinnsL9bjPtL7cV35xSIPpn2s71pPxEHV7v/+r677oP7RiNtFxlFJIZhm33+rfIC/SZ2J0N\n12//a1x6lwk5lPOwzYxCaaW1GxDf9HxkaTUn/7tTrymLiaZ+NjatOST4+Ho2wOfYPakY6slLqcYD\neppi2RB6P4zFB/+OC7IsvNzZPI9x/rCtF7/B/dN/O/zn8ZP6kOyy7bf1dxWHcf/6W7tQVtturPTX\n9r/hA/dWpI2y86s50LUx4lHbAqA1S7TWikupVbMTq6oz5aRlqrUnpLPkm4uMYMExO9u7dQFMOUgW\n89jXWuuzPg633+L8fCgRH23szbfv4+H+2TwfHt5e/Pu7v7v6bH6ocLV5s9U82O4Jfv0P8XyYDjx/\nftWz7vJ/1vNCJxFpLviZMn31pmAlL1Z/C5CWQGdSGqyxedCjpilrMxymjGmUfPx6NNXRl0k4ZbeV\n+Tz1oBT9mMXHMhzfvOJnl+/vK5fLM7+5WT5VXr59vyM+/8uvf/HhZzGlF2ajyRortH2++fp/3z/d\njpeXH338pK/3y5Pxr++fGkoq3UCdTluvEaVV+YasrUYArG6VXBXhmdZ+rNYTt5CE5A3GkpHY3vxf\nh736sTtu3p+2BvB+OTd5HMPfa3eO7v5Nvbp8f1fOj8AHb+c3/efXv7n8Laby1Wd//pu3L673R5IM\nyNMaqlcu+2/+oft0sofdZcS0JLuPLv76zbNCKj1ohlONC6vhTUeeTcvO5qlgUsDquKPkXH2YlEBj\n8wJAzV5WjBB63lx0mcfuvEqjGTN3hrowzzYb3t4Evy7P+jdz/wQLPshXh+6D7v1nt8cf3oxffPyz\nt59tpjnDTAZGNFNpf1n/37/8+PzTsZs1n2oFx8+e/PWbZz2VATIDcmy6CHWQrRR0k1m1QigYLZuQ\n0DJyFTG26TAzVleJSCiSRvSbn4+bcjjRH16dcarL7bZTTcybUnaW+685bse0+Vl3j8HfVpydfde/\nef1Xb5cvNr88fLqpM92RjWJNMyPw7X85nH8yvDl3P9WAcPbp4f/17mlvRpgLRkbFFkvQhXy0EKOZ\ncBp0LpmZuaUEb9hzG+olmKFN6qzpzXoDOzt8fdn1PJ6V3UZDPQ2mhHPo7l7fLDV83nb3p0P58E3a\nRTz0Q3n5UL473/3qh4e/G37YHzJBykuGrNlRa2D75cV323dhemBWXnzy9V+fnhWpyg3RSPvS1zl8\ndU81vgKG5lDDOv01ZSoSzBTZ9NyrEas5Gr2ZTlLqzv7ltIvrfnn//JhF1hODcym1v9xefrSt/fZ+\nf8ITHMwD8Pl+7/3x+uWfPPvH/jNM0dSjzJCTnvMxvHvyZJp26I5njurnnz/92d/snpRSiikzE3DK\n+r5O6lY+7/vJ7xEbAVlWrxqUJYslkA4l2uwkWiUoFK21fPfqt3+s2N7N/d2QdSNGnYbxnrt6gxHb\nszk3yzzczWcv3ngZdx8v0/Orq8Plyyef1tNSZIBlVjOjAmbuNcc+jx/c3I5XiPOzNz+bL0dfOrbF\nA9CWxAZVg61wZ2ucMh8F+a13wmrRBteJfj3UmpQ+IOVq+QSlcv4dzkLl6UOPg2DL/t7cxif1nYbo\nqL0J4/7hYrOPzNO7+5vTa799c8cnp9MpM2POrOFkPWUp7kj6dsPjVyO/YZf1mz/MVxuoZMrMkGDJ\nQD8sixemVkdHoynMHxvd1n6wgZuJaAowAGFSsyGZGZMwNFsKi6p32FntpZF1GS63Nt9149Aj7lRR\nErsHfTK/68yo4/KAZ9Npxn0NuUIwdAWWE4ciSRHBzsfl1XO+73r69qKk5CVpULDZbUar9dE6BQpu\nj7PF6hw30+PJZWwDB0XA0RTEDTR0NAIwlSINXX9z4+JZtz/0O0awe/h22FblFDk8WVhPl+MpYCyY\nt88e6vOrpxddDh2lIvdlqrHZNPUcM5bi7B+Wy4NVbk9sdYs1mXSTopYxKvuWJJBtcG2ySazuB8ka\ntwcpVs8eBElhlius0OCelQOhMVjKa3QJm+o4pmbn6+/m4RA2xNhNXmN88/DppaLG8X7z+auHL8bJ\nFrlFZEZazrLR1/dilCHF/l030137vjgBmXdkVpYIjSVr5zXNmxsttUJozESz+JgxQ4TRDGiVY5Uk\nUdmIfEPCSNQE4LYfhnfTeArO026cj103nA/LB/NkgO7GOm1y4K/K9UfPlA/Dn55ufrB5uOurFasB\n+cbm0lMRNYFiMKN3MEg+7fp5Us11P6q2dqnb5IIeNZZKc9JtVcM1Dp+NF88mXFpFZk3/ju+3T3t2\nIzAtniaW8u043AyjWa3nFn54WF4+jP3VMVSsO81nJ+/ms8D7X7zNO/3F+29/MN4M82HRPMnMet1E\nIWnm1vzEhlJUwON4WzqloS9OiQgIjshxyMk7kah1WZprt1kwGtOExuSLQAC5ZhkQNEWDDRM0cxLS\ngk/OIyFtH7I/ltIv80gt77TPCVb85GNn6acDFuwuDt89y5jzz06//WB4X0vPcFd6P9YbGwCGmpOa\nkSl0jo6zHU4YB0MVRHMA9IK6bDSra8YSISKTpDWJ2tpgks19TJPsEaJo/ToFmpd1dBexm954UrbN\nr88Llmq7PJXLJx9fz3Y72XK0cVb2qNRx++TJm2e2n/5o+N1H2wNp1GidDJoeyqimsENjR4mEo1pk\n7SeHL72UZs3ZIdUcSizeNcUxjYCihmgrX08BAWuka6prs9EjME95aVpKRROXLLdGI2Xjt/iAeRxF\n8qHevqxPUKrOh/eKQp3V5fLuzdX4ArdfXP5qc/lgpTiO2StM0zIUAErjWrKaaA5u8OiO/Tzm0j7E\nqpYNXPQ1HGJxawkDUGYs+T11gdXpY0hxydWWY1RLAAgzhFZNSYIbo9HcNvj6wpfB4BuvN/vyBU4l\n5qd7u6hlRpHVPb773Xf98w+/6q/uIeNytI1OmYGBBVQ2BbaafA1kJ3dT7rVhFWSKkEzyWrHRotLy\nYMxptsqe6xyNsUh5Ux+jtYHNvLjadJUNzjVrwgV6zKaaEXN59vWb7Dfsoh7eabOf7rajfz3zagrO\nwzL0r+JJ6X752b/5dn4+zWJiqBgOs0IUV8C7UVqrqcEcnUv04zB1EkQvuaC4Fl3bEsXRXAJuTTwi\nWUdEXQRjM+CZQXJ+P1iRq7WWzQ+IFgkDphsil6U7vTwFlzrDY/4govTd1OvuE3uALeeH6boMo+M9\nfn/z8WlRFRe/GiaVfigNy6EyjVq9y1KUDmSa8rTtF0HZNkcogyPncEbTyCdgRtAefeUZC9n6qJQU\nCcFMa24CmvfWVkgxDZJlyjuXzq9+8pd/NFWwdHfD5Ul2cVvri891jyLvjzjbX/4S+uN4/WlEbyBU\nL88turHrOmeT1zRO5XGstq6vIhnHzkvZOJFRrXOv6j1rFrT8kPYAXMH0751iEkvrqGAmNNXwCgmv\nnh0SYUiYmvD2eP0E394NuxLi5u3p491+6N7FgHgZi5Uol6dh9oH7T778b5+dv95qkdgN9uvvrs7G\n3p3xqLnxRAPsizk7HJyQH7yvOQnez1BQiV2ZsnjTR6+ILVvQDFYPKpioJWVAyyVqPC7Wfdc+P0CD\nJdVOF8UXN39/OL7XT/HLpb66+ug77v79r13LUjfWd7EpGy1nw6vui386++Ktj/fR08f51/sLLlw2\n2/PB6slVW/ZFklYGz+P+IRJAX323ZHeER4WBGSiKGFcFpNbYCYY1kK55WMRGSram9ntWsHknAkow\niiSDZxoskepemV3H2cfHm1Pdjmd5Uy7u+mU+51j6zJg/fqJDf/fwJ3fxk/fY6Ii+Gx6+O3/2RX58\nvLlLP7s+72qbQylL78v89t1pQG9pqMf/y9gvF2NwzhLwVPlICwv0aONqNa3hIgY8ZmKpSMaWCiPJ\nLB7N3HAjYU3+ajBiYfg2Xj+fjr4/nr55Ns6769+cflC/+SC9f8/h/Dt5vvyIL8u7L3/w2x8sy9Dt\nT+cX21evz85sbx88/PB+f7zP62eN8hVgtMOru8HG5/HRzTGTFjmc0qK3SAMMxU6192a91trZCoY0\nPu6VBGUFNFHIdD7ymjSBVq0AxmW1zyWW5Fk/ef9Hef77T54+me/9I/bdl7/s/t1m+XBZzj/4dvFN\n4Lr/dvjsZ8/10Hf1NDy5/M37K5/x8P6v7pUvrnTUcZBUSFXp9DB+efak5C2W8cRSFV3MTdlhysV3\nChRFc+K0D9GOBTVrI4AuQ1aoVLPoWwMLLZmGAJVo0mcBYK22KTmrnHb27M2v8njRXz392WU51v5u\nu+/YLYcXv5v7ccmy/OnRL95a77M/ufrn96PV2ZbjP70vu35jVrV4hjPoI05n8f7N5B2HrkZsQXin\nLjOAajjxo6iw70tltlcuNe5ZaMFAts5NSjek2khFAyNBAxGktaCUeW8ffjajouDbv/3N/e3N8xdn\nT0+d0cvmwRzV6jfPfpj77P3ti7N3n9+z2Fwvrv7+5Vj31eD9sa/7r2IoRgScop1tDt/+7vd3lf1Q\nErMDsXQxZghJeMHCrkaT2gKZsNaO0NQKDlZnYMJqZiZqtF/QJCZqyGcCEZmZUJy2P72+PST6uUqv\nJ/74y+/Kuz9czxeIotiNnZfpV0++uK7zyX78q2dxFJGbi5/dXiMcLh/+9E83p8Mv3vYN3BN8u/vt\nv0zDxZMfn51qsZMbLbOfwkzRWBXbqSXhoaWn5OprEJxEUq3mwWTeTEANTmjWy0ackGbu7aBO2PZs\n/hZy3GzLswNfPPnnevb7w9nIbwec4nbZVt+c/vl3m0udPp6WZ/cDDMP5L24uUHmZw8Mw9R/8+bPu\ng0OEQULYsPv6N32pz65efGBz0qn0JaiOa+c9z8Mu0zwkGI0gvweh/tXUISjCkkYDIW+pbEATvkim\nR0OqlRGvHr768OqY3Yvdu/v46MXP4uqmH/yBtc9acPeCgdF1c5tPn/3hRw+Tuo7nX99e1TnK9TA8\n55v377ufdmflEKFMOIeH31+dn3evfzP1o8WC7LzLHLL0oWjxV6PPYVpX0eOE1PrJhLSOsKjL0nS4\nJBQtBqDlwwnujd5AYbkYj3cvb35588EPnn79i7+Zlmt+a08OuekGP8/Fuq48FHWAd4m7zVd2eC0W\n23776vnl1ZN8OF1+9/7bT59Nx9jdDXMXopPFbvofwzd315qBVOiIktHHXDJZ2CPn0au8/YlXTAD/\n3YPWMqlQ61TL1prtoJXHx60EKEEZV9HF8fe3fNH/2YtPf/rq1M25eXocOw39Rb/czT1qoouvrUsJ\nOX/4ZP7ycHEx9OOr15fQdnvGh/LZ1U/+otry7dY5oCADpdftoT7dnz75N40cLnDCjma1R0oVqhpq\nWFmRToilrZEGUikzgVgq+t1YClbJLdeUHT7qXMAmoezGf+Dpp1svF09+/bfvr5bpwx/rzS0+3g/3\n1sGuxn5RdrXP0oH1+PFXn16FUMY/vL4cY3E+vZ7i7El5x9A2SFuMWQjU4fL19WdffHq6v3uw3WxW\nkFh81PXigURg9yyqo6ZnWsLUrVbvRoMhyawaehpQYJZAtZU1abXFAhKNolhw8+WHH7yb9ex3P+8/\nji8//CYvX3HzybsXD+f12bPX13HqOiBrFxH7j56e/vxlZ+Pu96+utxAWdePZ7c3+TPt4sd3M982W\nHhW5/fiZbh766f1mOGdlpSeW3M4uFIBZrjmrtL6CkLAkkCamWipUSNYXo6QCVmO2zkqP4qRG8Ta9\n5P7Jj5/nOxzuN/nUtP/gs/Li7MUSH/JDy5vt/Lz/mXWeTeB64Jf/9GyX3p9//duLc9ETOjn7cj7q\nrAwP9x2yZCpcYj1COfp47VtbJEQvS+0Q6aqbrP6RFpZs7LAJqLBEEpaSZbJy6NYBz5iN1MwGTTcr\nR8Jacp+EJX5wfWV372p/HsvoX//hj87OLlEP/cOX/u3E/s3VpgZEqzWmL/b5dO/js9e/eLJTC9p0\nHaEBvtt242YLJ8Fwq9ltx824Md+NJD1hkVClJtKtqzM9wtmwtVzdODQq1KwBk29HJ6AAisIkZCQu\nS7bAonYqGNWC5d5d+ru4nT/u79Vdjtv+56eP3y3lsPn6k89+ON2X69c9lUpLO108/Yfzp6V7/vAP\nTy7ax6dDCDM6kPQpO0gWj1K0AmVXjYQli1HoaLCqpXJk0FsqyRqy99+JMOQS/ab7Pj3SckUJ3duE\nvAIj8NU/HsnXX7s+/Y/X76v1tnta7q+2s3f9xx/8/rUdy5/odpEbkx7LD97oycnP+b+NF0Y9Ivhd\nDxCxpGpYyZbnKjAKQ4ZEJjPFjKgaOIXLlbaJivKvxiVkSwNDAhk5nHmmMpESzD1XV6webf+rigPA\ncpqxvXjy0z/7wdX+JTbdZsH97ecffSPD6fWzH8abn/yHl/dfXNhkQCx3F+dfP1Nur/73+dpDaNp2\nQ6R51vDCCss1xUT0lCsks9VO01qhfpbJld2odAop0lcDVMZKwmbtNvY9/EkUYuU+2CDHlqUBZiLT\nxzPX0h//y8unw7KZ0x7Ghw/8q2VwKHflavu7vzv9++27A2BgLJ++0VmMH/zid18O649PE0xcuubB\n8jCCNM9gGpbiLezOK2hmKhPjYhk2sLmefXCMrs2hrYo1aX0LXQvfeEvgowCgLBR9WXHzpsJo29wA\nH/122gSP8fE5u/Tnb/zS96f9hlRRnv/m+nj3NKd95iDT4fwn//zsMp7e/v3lsCFT1oJnKPMUyYY9\nsaVqiY4MJwEZItzmi2rJ/cVch5IRI8Ks6U7b4bTWA5EKbQrykSdb02sVtoKeAtspKzJpnO5rt13e\nTxcXy3iyT/hw1b/5qpzVSlqnZVvO3FSfYU5xPn36rr7Icfgv45lBaVAkGr/L4tnCQpsysqDNN7Y6\ncRFOS09VHLeCbTJ9iKDnGgG66jqt5VzGMvRYhRXN82RIIKL1I63paqeBqKgqzuP73Q8ub7a+/fz+\nd34PXR0ZHUl1NQ6d93h7/uLhFMv+yea7s8CLf9k/b5+0oxjUUqtQW4qlIiUlZFjkBHPtlCozg/Lo\nC4cliOyHmkOfLcy1acGQUiYzwwc+6hOaj9ZEmIN6hENbSyVC0W1qnSZ9+dPpt2Pcf/L69fnw7eJn\n96UnUqr+wf3EY5y++vLPZlCfvJmfxbPjzy82zbIQcLjz8fQLoalkm3NtElmldKMhFQZHcA5bLhSZ\nscVS7Xt+idZC6qzlx43eZAcr6Ka1JycgeNOK0KyJ2uow7uvZJ5/E//x3lz85vHjz+8vzeLoru+5U\nDCHF6eEyhq7/9Df/9GeX9dhdvNmdbc7+tmyGhCKaRVh0iowmdWRtuAACBegEoqbgqjIHzDTpOGaH\nqFtlsWkNPl0n0zU4YvGhYY5US09WGhpf0pARNKwUgIJYhu6jy/fvvl1+9OnX4/t3Z2NXnr7P/fP0\nlOpSrl9095vtxedXv3uLaf/54XSJq394dQ1vqUTrzxK4tN1JrrMbk6aZY0gkYRmgMMiNy8DNSXXp\n+4yC1gq2DOEW5yJk5ehKtRSJlXQyywhRRG1pGS0NAAJ5Xz85fHW07flH714vw/hkLtXunz1sT0cj\nWPyU9/2uu7n6PL8+RHf9cnh29eZnT7vwBrM2t7MEclHDJBNrHg1Uw7pGYBsQ1RiomZqGIDBr9KUF\n7NVVH7kGEJohS6fVE8CmyW2UYAvHZDuB6a2sG4u+fjVdPB2Q726vyv586WaeTud3B9W5dgU2/CF+\nMLz7h//8anNbcT0tn+7u/8vmjOiIzAywfdRYI0ZUw7A0ixihCvOgQzKeZHIL63hi7TZRbVS0dApX\n4ypaL0Jl1ejNdGLe5M4grNIdASlDq2Wuwb7zXK6vcJruDk8Oy/VhzLunX2s4+aY+nTEOy5SHs/OL\n7e5HZdpN0xeXXx358nU5bwxrooloZDJvZtT2QxcjOjeTk4VzI1s4l6J0hGlfrGxisZ1WDBlAhuBN\nAilkltL0UvYYNSTQOiiDxuQab7UqWyvPPj4/1M6mrvLZROxHHrqL47x5eZXn9WGUXf5g/s8PV999\n4927T78YzobjM5XOCQxhbo6GVNSkE49Zdui6cXdms6yrRMeQrOwrg0zLrhyLLUMsGCKtORtaWmfT\npJsh1VuucVFNG0ZCFiQdNBQ8ZjwAEIerzftvJnSDshbOfL/86O3wkH774f1Hdjt151oe/tu3tR66\nbliOH97W8x9/+eorJ1KZq2pm7YSQ9XuDcFdwex9d0ax4qAScXu86N0g13SZ1U7+kW0T3CHiyXXCA\nFh7ifTuIWstuXMV4ekRNbI24Wym1u9d7G1hPx3Gol1u7uxi/e6754vBsOfTL5eZ2mU/bj/7jD38X\nnrdn5xq63c//lqMZU8UtMhvD2xgMWXPv9NDdz3/93btFU+V9Z0Badx+dqALC/CQuQ+3Pmi/se6kd\n0ERizLRG4a5/0KZGUmG2pNJSmlcom6laFf0omKsvWG7r/fzFN3mh6Pen64fxu4fp0+dvNV3Y3/lm\nmerzDw98Mv/ez4hqLZQsLZuv5HE4gyiU6bj5t736o9WrGb2S6nUowwwhCXSvlsJ+2W0i2gZoenM1\n4k4E0DUthdl/zwYVCiyEkDlbZyCCiIZFWFrJfPu8v9vv/228f6pSw95/jrMZxw8/Hfd2ffMdx4ty\nPnxz/3n/P+HsrLuJpJX2QYWWualHvRDTNlh8MoxPP7teEqZkeb+4KDQrn6azoOdwSrN4JCya6Uot\nrtcZDUp7dLG39L2wUpxy+z6Ur80dJCLjuMNt2HL8KX8RfvrQyvnvNrzsZ70qm7iavn7xpPLT8e5w\n/udf1Sdlf1BBIh8LRKbUmBSDIHOUzmrttt3VJoMpoOh+ZAIoYRjc9p5H7y1hGS00ZE2Lb2lMMvs+\nwrYtG1GkkUXBXFq8mGDm9mjWMtThcjPaZD969rX3UmY/zDr68zx7/7vTs/638WLYPOHPl6u/+vYX\nH+w6l8t88UbsklYKRcK5RjxCSfPdWT0co4Hfwz57g3rZAtG6Q+nmfhtVtgawU60JXsMUSxtoVyqs\nPR9pVUsSVdm2FBRCag2IXsbzV8ftXD/rfnF0VT/eHrh9F6ezeO5Hrw/xw6HfnP1+Wq7i/7OxUwY6\nY4b7qhZCSsjMFFRDiGzDQBsDAGLg267l/0jpEbZHn0MXreS1maHdW9D87K1/bY0WHrdLa9YbMbDa\nvNaEcCKFubt+i6p8wX+sTwZGn0ZMuF66h0u9me/8E93nRX3bXX72v+psUFZhVixgINpCaqHRDdmg\nrfpqEOEhgervlwLKoogwyFVj9lLRrN2tSNs6ojZCbJXvgM02Qaxy9RRZMpZc3cyrh8tCz27rdhnP\nx3e9+dBP2S+1Dneoxx9dvF2Wj7vDt8dN/y8z//Tnr663QS/qCixRKo1YYzKZLW8pM9p76ovR+r5L\noPhNTwq+WLWlygpUYFbpa7AAqRYeIgARfLy+RSBaLiJINFOjlJCvS9Cb0JB1eRF3ZfHu3XKzKcti\nQ2Hy4cny24vy4hffLi+ev+i76ye/O/Y/WX55vnuyBLKiy5bevAIu1rQIzjVYEN0w8IRPzg+589C4\nPxbrgXRDdl2XmqIYHhXzAhK+StHV1HntRpb2LO01qQWmJVdtHleJqMAElt3mZRarb7vzjSPqsmx0\nNkz7H88/e5ivLgf+tv/RF0+/ecknL/7XYchSCS8JmAUzm4tKEtt+d+/GUoaOOd29/u723d3h/W0O\n5U1XOMPlFcUrYBmoPZuPoG2ORxYDjcdrIENDoxqUQ6IAYZbpCZloK9YIxL7/4tUyflTnzfaXfVU/\nHfHDvfXLw8VnJ935F5/tbg9l+93Xm/iTfzp+wOmFoTrmzcNstemMWmi6aPLDq34cugxMTKD44d3x\nEvXF9uG4dZ8pF2diPuuBk5XSIMxcZefrdMf2xh/V3aspqhXAIrolzFDav6Ma3Tvbl4cbf9Ld6jJu\nz86W9HLsxofNxfXd9U/+6eXUfXa/vZT9YV7+ePOrJ13lMsqQKkTUgeEuMWnMcCaurC5WrKRmmZvt\ncud5Yd8t5hI47AcIi3XG/XXv68izigVzrRQS5aAa69RgkPW0KoaUBVe3DdRyzLjgj/W1zp++rcPJ\nRt9jg81p9/5whe5av94NuLqa3/zxB/98LJ/96P+hDWBnLi8xL92BiK7Be4Y00JTjyOfKS/f0zyn7\nYXe8Ka7y5nDhlrSScXlyK6J48tX2IpkAz1x9SlwXWGsK1nZcJiFXBRKbYhjrHShA1k/0W/jlw1S6\n+mFZTtbPvvTx7jT6NH1zuHjKmvP0+jfDD37yN8sllJzSlpCHIRJa1owPSMiExTQt0DxVH3a70tlx\nrijLqwEti36PIf0EuSz9+80LgGHrXRCP0tV2Vul7XQVEorRRUbR47F4A+XY4fhN+Pr4eH+qLzanP\ny8XHi+rn+zN9eHsV9szq8dlf/k/6/JO/+eo8jVbTuik9srCSDR1UNjUTkPDIYMl63MMmdFPXWfdm\nuQCjZc94luidHqHSQL/HS39aKtqjZp5Maj1Twe8vOUqijRjxKAsCLY7TjXfDx2/zfLm7eD91ZXu/\nW57sb77g+d2ry+uPLqZnnMv96bMvXv/+6ehmpXSKPkGwoBezCgGasXNkBEE31ETSvePU9fTTfuzM\nUdwE1hr0xZp0k2YNkyLafVMtOHpNM3tMqOEaOxDKFmElWYZWqwaSXKKU5aK+P1um/mry6pzu6nW+\niufj5rvf7yrubk/DZv7w0/3fjB0oRVqi9CoBzoFeOdOUiqYfbXIPo5bj/uHmbtnusiyvrSw5B0su\ncDciKDKglbJsBJ43hBmQmNHu6vn+NGquEzeDI5baFdj6+5EmeMfUeKC9rdt6GOp23037J/XJy91c\nnm2momm5m/p35/jPZdeM/ZnijIyoXSqKOrabXPBYjxJQlrEvZXN12U85vJWbWXGaieFo+Rke/ijk\nWkn6xtM3W+4aP7nO2hLoZqSbpcTsi5dmAUIqG5duC23ybfdq6ms3la0N+ydFxT7Ilxzfle0wfvKD\nn5/OLZVImoMKp6xkVKE2P2SDUREtBty8bK7ONwVL3S7vziyhQKy2nZBbKud01hZgloCyqmEd68Vp\n1tw+JoSwXpbWrO9UKTUaLQUCzJTM7B7nrOdXewbGunSnu4fj5Z3j5qiH4SO8R17//ddXvTlbGiuc\nS7QbTSxtwUyuXXMzEVopDDNhqRFl8yZ9cYKlyKSCpDhEKU0/vyoKE2iygxVBbd7NNstQTeetVDZ+\nBO2uH2ulUyQyZXkEWCmzdFWczajezWN3Fffst3e5v//FtauG0GXQFAalKkwI0zDtI4F2X1kTCbpD\niDB3DHx/bmK6Z84byUHjotOgZDTYR+vxlIbV5mBE522YlLjCiBRBg9UUUR8fcKU4jHCHney2Vw5e\n2ccQvP3dDa5YnsDUnWF5+dmZmbNJFIlwQ5q1Nk/c1EOV2jVoMMJgg6vW9C5juD9uKx2ZFu3OIDB7\nq6bKZvfGo0pnRdTAhLJr8+tqYG9XGrEMxmKk1Ug3sN2EwzWQI8b5kDGUSj+l6rK52JS4e1jszN+d\nP/Tl+M9lFhnKuUVFw4R83AlzV+o0Sd4VMxOpmJN0c1MW7b1MHQ1RwiaHV+RCCQ7SnTBvUOeq/29y\ng8o2kK3QrwQzls1uNC0Sc2bXt7qZZFuPcVzm94qLY9I2IMt97XZ5jjlP3Yd1GjeveXVgyTRYKZRk\nYTSEuVcGVEbXaYr5dGpnB4iJNEQETehhJjM4BxksPZ3uJmQVIzLSmyqyddwCjMk+0a5/aZ2Jl2E3\n5qtflKaQDHA2e9SRSci42fz4u8/eTB8fT6PZkiw1fEi7Pb/6bjmru7z6o0+Op/ttQVnMYul7JH0x\nJWnqVJbSpUddqli89CikVZVcqnFx7I9Ls0tHhuQjDJRZtQcDo3qhoFVASMCRVEZXEhJXvyjp5tPb\nt6euEAlF16J92nFMBud3/B+3b/4PL/f7zbSr3cEUPPZPXjiPX0y//mLpdH/afPTpN8etR2xKIkNW\nC+UJWFazZRlLDGOiM+jcJGPpkm4bnOjzw+F2NwACu7SFYTIoQuzAmK3SJIsCoERSSiYyto+w7Rqm\nnqf3t8GzoRgYZKeIdsmpIUkuD/xP0389pjbLZV/Gin7fD/5QNm/7/ZM4/+zu5R9d/eL9T/78T//x\n9twxl4gCdZQJYoYSqQnWD0ZfTSKZrtUosjnvh7+8/+5Gu9E0K4xstJ11WZFY6mauNJlqUw/L2g0P\n2FD6fqTNVH1z51tiazK0K1vAdqMcacrT+f/w8q/n85dHRtctXu2c6upuv33xcLc///why/v85Itf\nfHf9F2VWZ12W4jC2m00oB+HyBKTlcDzs52VJZqSwTNM8Hw8P8fTf/ccPTjf7FEl6B0RBpcuFKN4D\nQiCixlxDsuag7tYgyebiHfu5bgaWbWe5JMxowdqI5kDWevHjX/71pv6byS3Sl6rYjNm/unv6zebz\nL843F1/V8e2U+2fvfz3/xTS7FKqgWUVRR4RZmkNmK3tmdJPAUIIszpz3b97Yn/zVx8vbe3iKRYYq\nl2xAxOB97wZmSLnUiKWG1j/8mjxHqvD4dir9MHRh20JVlLrU04ImV4r+fPuHbzb2Zf+qL3zflbrE\nPGrUV+PlV2+uL59+/XCeZyf6Jz/yf9j88H6qCnMLJSuyxnrblirMMkhlOpIrG4kW2UzMN2/yx//D\nFw8v7+fIpbarBeLYRcAiO7q7OVRVq2KeTss86/HuVcDLGF99tYwDwKXauOt7o1VwzU9w+BC3x37z\n9E/+2+JZ7o4/ygWHDc6H+qsnn8RXv7755uMROdTN7fsT/ubqYnZas2oSdCS7lDNRlVWWK2yfaNcQ\nmJHOzCSW++/mn/6np/cPUWw9ah2lqi+NtHcrZCGzZoZ7Hg/z9zIjG+zNu+5q57RMs7DBvbCw5X9n\nStovBXk1nt7tSqHPy/72xnuca3vz8vCDD36S5bDZTXfndvb6iye39Y8PWTMAs8LI0rVLKkrj4jNg\nxszvMTOaEe1il4Tq3Sv++38z3y0nMkUzxZhLm+kEuXVd6ZxUtaEfWclVF1By/vrt2eXo67WvkV4G\nZ8kQW0CogYzDi6vjvxRXXfofn330/IsXb/DuFD/uf/XtxfTVLl91Oyud3fzix/m322fHufSFBsAk\nM0PB1Aex7pCs7VJBBNbcOD7ejmea3t18/lf29iAjaFi0qehWk8cq5rauMytuUfuuwc5Wyv7XL+26\nW6YwM84GevEO1TLNm5nfjFP50fvbepUS73/32/6PPvzdYsfNT599O3z63csSF7jvL4/WffCPb/7T\n61/9ZK7talZAiSoFme4hAOlFYT2ygX0JN0WVoHZ1K5j7N5d/ZTenXBYlojLU0mYetQki0RnE0EWR\nDClY3u/Pn26b4V0NZ0UxQwsYb4iCEJfb2/5hLF3n5+PLn3/9i/tZmyf5i/tnT177BR6upzdzRHn2\nwd9/8Gffld0+GVUpc5lXsEvari6VpqIDHIhonVALPpC0Rt8B0Ont+V/6becY6eqGMD16KdkE6HQW\nRoVvFWhY8HLktj+FGkdmJlYFx6HJb9eTRPB7+8vTIawrG396vXHf2e1Xv+d53D05YafTde5H3n71\nxxe/+ol+/YOTl966ziiXdcUKgK5TlM6nqRsluXs2WQfbqdkM8KAIzG/P/92rmyWwdFH8RDU4AliF\n0IhkBlDGVe2Feqy7HcNhfUcErCrpAz3i+1h+M9jx1d3PVHyaDrrv8v3RnrPrypPdd7//dPNu8Xs/\nj94tuvNvf/ujn38+LXNmxlxrCDHPU8rNqoRl2fSVICNaTyoo2n0OuSa+g5hvPv7JzQmI5dizdtYo\nsgAaxwTQSlGUPtbpdnnwM59zVqcayc5q8y+VPrPhJoIAO+su3/WENIxL9NDWDlelzs/9dLzsauGb\nYbw/cDN8lL/+o+1mnJiiFXOgmsGyB3sxj9qWyNZotwM/E0AmWyb644VZx9ufPj8oq+UuExFtkmvc\nJwW1C1ztkUXjwzQMNYt1nddQgS0ziyB0yHanplGSnbB5IQhxOUulQ/HDZHg3XhyXuxc5bo9vxwPq\nHa/z5r9eHDeTMoHIzM4x2Bx9ouQi71FpQTaitw20TXDQLDwNioEOyx+fDp2pjlO7sU9skiW1yVqx\nVPYhgbBS7zFqIlAwz3KazXeAydw0JUJr3kD/cPvinYo45J1hbgT0tiy/+2Cbp+PF4fxTHpddZ6cb\ne/7qAz8zc5m3RLtYUDtfUsiCZMoZjW80w6obX6+xfiSRgHpz+fn9gVOOuWbqrVqfloeSdeZ4da4Q\nJeMhh1Ll7DjNgls1g5NpJdYNt4Ip/q7OxoAfrEARY87wGObb55eb+ykXbaPuL67j9M1zXZVTLD4X\nl8SoYpeFceyKicwINexGzY3+qK9C2JrZRgKYDz9djggO0cD+Fd9Y+/KIYei8jYJm077vLYvAENyZ\nsqJZLNSMYo833ojycedjoE8UWQ8+mbEfOXTzq280qJ+PE8lYlheb/YPHnVtn5jBmgWpyg6OVREYx\nmmwFAyORLXINaM4ktnuZKeD+8qP3Ge7RwhKamKgJ54noNh1Xl7tpv/RczGCMpBlW41AoalFmEilK\njtRn9w9l7suCTJsPY7nvc7audMOJUVK7jS893vvDZ7d3+Jc5a4c8UWlZPKDl2G8YcDFk1lZPY7ge\n8QCJgdWXIUE5n368KLty6tolYVhR7IQUdbtpGAzoPh/NlQZrES5MFNOcELyAjWLjiqbotaH0U6a8\nPtQvKkc8OFU1+qSYtro9zmXsf/3shz8YXv7fzxcIQTctU+9ZunEIqAiwgnwUBjbF44r6N3SGIOEi\nkfcfDrfLiOrf20wIb+H6id0q9SKQx6WHDNYum3bAYGQNh+SMRkI1CQOXPcrmVI20o/2E7+4S81hD\nJ4zCcjufHYZuutjuvnr1nNdeuzrV4nRXVmrsFQj2aRF6TOhsfkJbs88b598+RaM2T/hwv/TRFjXX\naxfTjFlradoqALT5qJI1U1gWM4dUrMBgsEhvIpvkCuLue7/YpxNZ9UX3Tz+2kdxMOQBd1TCyfDn5\n7D9/tvnhFz/7LxcczmE1pILMXh2UmNLqsB6ZjS3G9xEzJNPahRrCaleIw+fTuKntAaWkLMVgEjki\nMtMMspw1ukmlpZgboKqCJFJppWDtASqt0mOXm3TKtPlw/750GwKmHFhS2sXv+Xz3sPTL8PR6+jf/\nt3sexuPgu6i+ZPNIs8yk0VrKSoDRyNRENjX86hD5noxUHp71e0/Cl/VeZTUxMugOtBgis9OpmJII\nb4lhEoACBkwGS0S0Iw60jKswGBPdPMz3xZ44Ndx0nfeCnS/9cf/0430uT/JvPhj7Jz8+/sfb0yzU\nSIkqnnJkZ1LYarJtSk1PlTZMi9nybx+DdKiTv/hVF5A8m5gFsGBAREGzXcLyFCOiXcVdm77Q5lJg\nNQpgIefqIKdJXdzC3ICKW5156IiuO+5MkZFnXwuzjnfP31365mH3sxfH+PzwQ+1v0pnudIiuJVIl\nYWJw1cuuHXRLmbL1AiKsXCTq9OH/sxdUTVijhBqSnSqybAdcnSkQUahQ7RwZbRarSqFRIY/kJzIj\nZutrIoENw5dK7OnDHMOBs5+en5dpeyr46MfLxdPL4Xq5/Ze//e75Tktac3a55gGd1kXTLoY0ZpPT\ntquTEHqEvtmOoGc5C03c0sKlwISMXtc7QYjaLhpWUhV9J6CstCzQnHhNGqAUmD50lqHQgquhQpUb\nyI59sCxXD/78zeX0jbrl5bsPr97qfbmI7e5y+SoYTeWexuPSod0QuiZxtBfU0kzEBkAxHVRmu432\nsLs+0MzB75PkgDTQj5PWC/uWNknQtJ98YC4VchoRlWZAiyJfVcasQ2x8Vmr3oW4XDNI+hzJh8Knv\nTscn96//7E/Oyk6/v/9sev50kQzeuQDvgkiKJytNwQFThkCt92UbHwMmEdluvKSJSMz5xcligpVS\n1gzvBJFWdHe/r9ZON2VWdZYohdkIzc5gOUOCD+vx0X5uLRossXk6vn3fMoTQlcEfdn74IB+O+w9f\n/ePL6LXU++2XfzpbXUos5ouJ2bkTlvuBXSNWUNgoCbVs6ta5RrarWFaLLkTm4QNZXeqSKaebtWgg\nN+91f3AnGekOpXmdNFhmhdr1gIRb1JqrFq0dHWUJLknWd2/DO5aFsP4wDWU+mZ9kw9mLh/e/evfT\nnu9/s3v5hzHVkxYpWZgq2E2nDawRyplGmhvM7TGUA1z7RTbWCxAY01V3mmpdalUFihOGpBm9c++M\nxFzNTdZhVu8ZBhQXF2uX9pWiMFulWoBgd04XaH1BWLE69yOmJSq2d6fNctpsh52/eTv4F/r262Ej\nzoGa9E5OkV4ezFrLC1qT7AGsyEc6rAUQrld0YXUxLP321RJLXXJZllxEmhUSsST7kiAWs/mwuM+L\nulIjgx2kxQA+2EArXa3rLAaaDSdDP2HDELg9mRCHYqg03W3PvjzOL/b1efnVjbG7XrYx9a0WqGNH\nSN3p/c69CjI3WrSQvqaBUEIyZn4vbFG7b0CW2Lw3X5Z5OUWdlzkj6X3Pepjq1iDO8kx1XZwwdEsS\nZkyPMDDDU5ljrKOGIGVfRzk8F0Z2WT1YpM6CY+LZN6Xc55/ETXfx4Qd58zXGBV2SNBPMQZi/1bZj\nEopmBmucO7WyYEo9jn7rw7SxY76Sb4fBgGU5LfNU5wqo+uVohFmdoi619DlrGDOMYaUSKVOtQkLs\nsM6VgKHMt9sS6o6WuesO7CGFd0VSZ9/a15U3+GF3Njx798/dfhiyLJkJrPB8P76/3XX+GGoOWQud\nedT+Y30erYkVq6FbiXkznFBK527KmKf94XA6vY2zYelLShm1TlHiULsBpxabJkBm6jokCV9VGk3c\nU+a0/dQjlUN3LFwqhsglNrbEyKu89z5vP+jf/e7JzopN1smRkaR5Devmt37ulqtwpM3V1l540zSb\nrfHKq5KhPRJR8fwYpHvpShPmnu7e63o0OgRqqWGDp7rB0zzTTMRcZduxt5PRPK0NTu3CgaXfLFOp\n6HZxC8Pcl4ogkJrrGDu9fD/sloe4OvO5j4oKCegLjPLRX9WL0ZeWxPioRFlVNW1lZa7quhVjwGpU\nyuOLmusQUgY3A+f66YanqaMXP81hY48p0GnORDQQ0mQD+7KEWL3XYzwVoBxyqOSmNc6etjkVc57k\n/XEZxpHvv9bTD/o7PO8yC1JpwhDuwNa+PQ3bLsBswT/tAzdnGZPtDuT1Ou51k3x/rdl8WY7RJOjG\nUsw9OSge0NPL8V2lWz1pHLxWAbIO0tCZLc0KkiylJYG3riHLaSCz3i80AbEs0XklQzZOqByv99++\nLNd+/6arAj0ElEgjN+Xl282Tba2QoSknVxpe/z1hpgG0xOON6S0PEtJUzvdtXE+EeUkVHJZlPHe3\n+28W887mKINNJ6Gic1Sx1LSk9d6UxdHGTUDmpxITPdsd9R2cR9b5OJyLp934nlyedsv+F9NP728K\n3AW4vNTBxh1+8+78bMj62FpDmXLTYzLDmiezmo6TrjUEliQVy4t5aaMDiGBnfe4Tu53i/vXcu+Oo\nYWANkoT7nNSy9MVKItt9go/1lQiy2EJ4Av0REUB/DqFEHSufjf02TmfdD1/84en1TgQUsBy0G4Bv\n3t8/ueoVDWMCYE0V1WQfIIAQHRQYJckF66WhTc95eopa2oToTNrYL8+ts3o67CtL7zU6FKmSsL5k\nqGAuVqyrR0sBlZ5Yffukna6k6m4zobI5GPPigfKucq46fjLmwK9/cPH+h0MttkTRstkSh9u7OL96\norU6o3XZSQneoKfmTs8EGshM0Jp4EC3AaNqNpydq9g2T3AArlsc3kxlHj6UfHTUDLKCCJaccvFTU\npU83ekuCaU4CbTCruGk2S8roNltGlnPw/qNvNH3+L7615bP+/VaW7NEb7k97Ha52xRdbAT+0TBO5\nrWFUCGO7Ht2Stspbm+2+9aVC1bOXYesJRqf5MoeOr+at1Q0V3necFisTBzBhmnNktXmhUxaJfq1M\npBO7wx29LkkLg6WYS1T4sNsg9HR6G08xnb/bfb4sVDVmHO7vdPHiw0tnI+ASUOTjy2+7HaLa5MdW\nch1Qg45XPsOk5elcA5I7gjR6f7x98+1UMjvDrM5znmosGDqnAVWlRrWYqciadaEhhMb+YaP0HPua\nhV60ycSZkHFguXN9OMf7s75Cf9d9sJS+z7lmd3Z1vnMJqkttTss0ga1CrPtO66ELQEascrCGU2Pt\n16fzbubj/cmkbMDD0caxGw0VVjwWlEzvMxBGukjYwFColdUkmx/K/eT9qCVAVeND53nq7yNxOj7F\n7XH3bP+g8+P+/MObYI2lFi/uZsVcrYeVmtPIRLpB1m6HYpM4tXMkH0VTtUbVmuoizN3ZoU0L6U6k\naPRipaNqluLTUs2NDnRMLEsQNQ2oy2LjMFz08YhDW3nz+3lJeJohasn5fFnOjMbx26vndlyuNg/7\nZ7tv3g/5B1Bdj0AxMyFSxCqchTXAM6PdwfU45rPpz9YxFu1BgdV0mDUvTwlrCmElICuAuaXSvKvL\nFFHTO7rFKhxWmoJA32XgjKsOKxPjR2dzjpmyKsLf76ouK1DGef44336XvT+8/YHvjxdXwzQvqW5w\nQ9bmWo0107Vmgg1Ue7x+qMUNklzZSWRjC9C0dRKJ+VkuQF0DB5prgL0J8KEgwc4SfTEtQsZ69Rqc\nyTGr5t6j6cEA+FkRKqwz0W1Tlw3n2a2byx/Oz+f9ha6w3FxS5erDfJgr+3YbX4sk4+qRa/os0dsE\nDkDREl/Xj26QU/UxiG1lDY5n3UkqanGUpIJwZ0aWEqdaBhNL54CAUGQ131iCaJGIwPeeCLt/PfWa\nCxBk8a689SfHqRTJltdfeL0YHtiHTrfHt/anO216BemFIE1BFj4ylGx+Laxa1CbQbjePZW0EBR6n\n8VUOWcvFkrlUIsMcgHs3KJUl5zlASKU3gqjzLPPMrvz/AL8n+AMGG/NIAAAAAElFTkSuQmCC\n", - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 输入是一个batch,batch_size=1\n", - "input = to_tensor(lena).unsqueeze(0) \n", - "\n", - "# 锐化卷积核\n", - "kernel = t.ones(3, 3)/-9.\n", - "kernel[1][1] = 1\n", - "conv = nn.Conv2d(1, 1, (3, 3), 1, bias=False)\n", - "conv.weight.data = kernel.view(1, 1, 3, 3)\n", - "\n", - "out = conv(V(input))\n", - "to_pil(out.data.squeeze(0))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "除了上述的使用,图像的卷积操作还有各种变体,具体可以参照此处动图[^2]介绍。\n", - "[^2]: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "池化层可以看作是一种特殊的卷积层,用来下采样。但池化层没有可学习参数,其weight是固定的。" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pool = nn.AvgPool2d(2,2)\n", - "list(pool.parameters())" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAAAAABVicqIAAAZt0lEQVR4nEV6Sa9lWXbW96299znn\nNq9/8eK96DIiMiOzsquiqih32FbJNjAwMp5YDJAQA5gwYMIY8Q8QI4QsMfCkQJjGSBjbYKrswmVX\n48rKqsiK7CIjMzKjyXgRr73NafZeazG4YXMHV1dXV3efs9c63/qazd+hmrnUlfRPPmyb6KSHOkYJ\nhCGAhIkL3IsQEPeQaTQ6SumGCzvj0FFE4UIyvT6xb789rkg4hRpd8952LDQ3GAgdvBQxguJ0E4Bu\nQndXKpywoIRR3d0AOBzeZQ3B4AQc8O68CdfuqIoQhNEZ5eP/LTCQjMFVzeEW6Eo3pxlIVwNAd4cT\nBvD58iwghXG+zB5iTFFIV9UHHa69nIu7uxkIxsd/+UTcCUgFLYV0Ak44ACfhpACkgySFBBhIgrBE\nCCXYIpeYggQJIBznh1p9eVrUGSgMVfrwRz3EhUIR10IhXGy1isEdbu5mZg5XLcXVzZ5/62YwBuNQ\n3Al3GAE6Hsx994tubu4IKXx4xxBFQISKXiBW1I0AHCDgZgDc3DzQclZTs1KccCfdQRdhKPZ8a0gR\non2cw+u7xYXu8Hd+ko0iJKUKdAfIQDzfLJiv3kAXV4O7GszcHauqA4CCsFLMnQ6nUISfner6L4Sc\nPQR5/x0PiRARiWJFCTMtjmIA4OrmMHNzBWHmAlgpZv5X9wkDGGmmpZibwwnAkD8ruHazL8rw7k8s\n1cQgAGPuB2B1eU7CHSThBqGbwBxOCVKPa+SianCKgAZzumkp6mYGdS1FcXhk1VeanP3Ddy1Fc/VI\nD9qWMHGKCAxiXD2CIEgQEHNvQorjtRh9dvy0j4xiBqdT3RwFoBsBteKGcn9jevFvvJU+umMhktkl\nepB2kddIAOZugAACQFzECHf1OGlymFZMqpsXvnD0wcID4W5mcJA0JWEGV1PQnz66WX/l8Tt3LET0\nJm7CShfHbRSSsOKqVuhm7mqmbjpaH6nYktL3Q6uBiFd++WbOpZhbMTUA4l7MSynmDnfaRwvbvHG7\nMImbEEEqdiePlSRI6qBqMFMzdbhbiuXs3HVou0E9u9TBLcc3fyEscimrX0KEbsVZ1AARhtDd0/yd\nXhKZDYUiKc8fn8FBqJqRgTA3e35VcrToYVB61/aMLE4Cvv1r1+bzvpg5nRKCl76jAAQgkOrp4Te+\nFYUoknJqTqO2x08DSHeHggI6VkAlxrXT0EWJFqPTy9COEqMIneFn198/HkNpihDE2zZILMEBoVDC\nf/y2S+2DN7kaH31JutkjE7oLV93kTkJWkMXSZqHCtMDc1VOguufM8cbX/+nmbMjFCCejHz+ZQ8xJ\nAowP/0IZvViTsT5c+jVZPjkVAQCYklSBrOCREHlmVQwciguAwd1d2Lb11lrQNvxGbrWompsztPee\ndgICAKT907lX0bUaOEm7X09y9FAC4e5cbakJnAYQkHkWUtUDA2NMpiUvziYHI1VV7a7++rwr5k7C\nQ3z62bnTAYD8/nlsAkrtNpqs/czmXB71CWbucDoYAszoJN0l5ySiRauYWNWelXaEa+MM6xXU/LWv\nLOeDEoCQ5elxK3CA8cHtIYp3Il7JxmtbeV2OAoTmK4A0dcIFDoCxje6Oqo6Mbjm7leP9XZXS5oCS\nfci/9Q+sVQdIQTi7f5IJY+z+LLMWQ9WH0Nzc9kktCGAQ+F/Du0AdTht3fQkwYxUnElBKWcqNDe3m\nAxJDCKq5fPXvzAZzdzcJ5eGDucMpbz2NyV3rPkJe3S+xgUTA1QG4rHrLhUJovroQFEqMSNq753m5\ncZ2tZaMIDKpa3L/+8lIdBBCku/upX6iqB7dDim7Bg7QvvqjV2CzC3N2wGigr8Aows9FahkQTiZUK\nbV62Xlpf1oOOaijEjBAUn/76vzWYQAh4++Wf1+B/lFOCGDUtrn3JwzoU0Shubi50z6pOhxmsvLB0\nklECS65k1t46sCVmo9oLkZkiPUez5a2rnQvdNeeL/+xa2Ujpq+8DyLSo06+O8oYoXEIIMThEBHC6\niASBBLkyz0DxkPo26qx/Y7/vB29ywJB7iZ7V3IHBfrOYCLVdfulfbaXpMOteA0xtgPibl/KEBY4Y\nGGzFSUIIIQkoCOK+cVe8b2LoMJbWX2nmIg206dmHZjV3hiqItFe//qmKLvt/9IvdXmpzTfEgnYcw\n339Nm8rcTWIkSMAhIhSQpAJo4jms9t7q2GfexJlMKg3ehflIvE2s7FlT04uWW3JY8sVfqO2CtRYT\nngZjzMHwlVFuHA5oJJwxEMJAQuA0V+S9ZaxtYGPWYrRjg05cBmvkdGreSq77oZEhGGTe7nxev7je\nVql3CUJ8EoMVIN26lms6aO7RxUhIEAbCYeJG47CPJngIEi3KLpc6jn46TUuMVHqJaR7qXPmQgyvT\njY3B15FdJAO85y4ews4blAR3AJRVC7sI3OkON7i7bVq1NSIy6Ttl6JswW45tHpKxYOj7yL4u/aIt\ny/PMYJNmKArmZ6onH6k5U76+YzWdpDBEA9zwHCJDoBgd7rGMB+37SZF9xfl2MJS5r4k6tIvdEOrY\n21APZ9wWroUBdAl8tMy7d5ZJMmz/utVhxYDEBWYGEZiDWJE0QEuwECduJewN+WQqfWZYxq6YLdT6\nbFh2XTssumYD44m2mUIJZ3P73L5lUSHTL048Pcd2IAJwrgaKCAADYdARYzVMtd4puthOfW7CMM1h\n2YdUMDIbMHeWZttHsYNIiEDEY7RVfj+pRN+7bBVdIeZ0RAQLFR0EaWpOC07VRkbZh/PNnu2W52UK\n3drJPI/RPHh51Fr081TZfkgj7z2JRELi4/nH7a+8PZ+6hHAzWsKqwHCPhAiFpK9IIyMImFW1+2i6\nqBfbyF0Vl9PTpcw+ib8aP+05OZ63p7+1m9eawoAQIE4++cH78rP7/7pyrcrBQRmJwelwh0cGZRCh\nr+iiwwkynI/XHh55jic7krtKuumJ8lTH7U/LXYyPj5b1L1bVtmW6B7rTP/7Jw3byxmv/+agi2L8Y\nWIEQd9AdMdCxamenkwBocD68ao7JcL7R5LZiHs365mg4PGIt69XD2KW/d2F7VFYFZAjPvvlA1196\n+cq/eye51rhwYGti9NUyQIwWXOFGBwS++hCqD3/OpmepH0+GRc1uzeejD878tH01/vhKvxxP/v7a\njaqrYSJCefb9ki7v7k42f+entRQP/kqDaiWnnE5oJFalAAA+ZxMu6XAZd8oT27TjqffVUSg/HvJZ\nmeCnlzbvLa//avMF9KF3VuDwR+9eDa9MQ9r43e/VwS3m8bUyJVasAnSnQOCqIEnC3GHmCCn95Thv\nnm2U0zXkCt2dDy2fH7TbH9ni2YWtV5o3ksEHNfHhm2+txXRqo/jb366iwMkrwRvQSQGFhAtAz8VW\n9JPqBCkMox893YqjPNSBkhfHqbYn07K2QN2Wza/pFzBkU42hWP3iVjqrS/Po39xJDX1gGF9aVgG2\nQi2u/s8dRd2dpJg7aABjCnfXGilLWXgoTx4/Y197XR/1i2U6Tde75bzVIrlpmnjl4OH5o/Tw956w\nkugabHvDR6u5T4qDQDTQzZ9rLAmBhIOCVJrj4zSe7/jJe6VdxGo9VzvLvLy+d+HF7TQsK5lN1mu6\nVLu4f0VtISKIA51XpYGt1K07CK56YIUqTllpGglgqI54Nw9lenR4cshzOVPWn+6Nlld/7o0X1Nla\nnFejSKma6e54ftqzDQHimb61z/V+MSjcV8V3xOdiF6S4CBnopru5VCcnfZObMHt8JKVJ/XG8VC8P\nL/7yzv6isHg8aeokZJDqwkY8fnRx1IYQe1T5YkrNoB2CiAhgDkSaqa/AHohR3MnQeUS5vTHb5Ozz\nuS+3BZK6Nuver0z2xv6sPu+HuhIBhQjrm+OzY6laETdBvOmTyGzmvTMKHHBRd/pKgtLMAQqliJZ4\n99laHD4405RO9vqSJs/KO39rcqFjM1qcaaoDFKrqmOxs+LknSvAcyuUpa0McjaoE7xftYCvShRXZ\nptMNgMGJOLl4Wnl3No8NuT0bwrDWnH0p3BgVVuGHRzmmaEPRUooitT4fNkYSMoVXvKK7e4ixChTP\nXTcIKEJ3CCBOPrc9KHuPv/Pu1c/e3z6/1PGsjzb1buP6a+MU0tGH8vhP/uTjbH3RIfcP/u+dTJ1t\nmlgBmt0ycTXN2Uxi3VT0vIgQFDOHUSCkUJ3McmBbX2refhz7gwcb7eLq/TIft796c7r2+e6d++P5\nfp795fx6rdFP3/l4GO31pVsrSU2420h0d3cXuAvpVjxCXYsDwkJRNQNM93clxpe+e7Z76eHupxfq\n46MXP5odf/XC1pqsf+etrhvGy1dv2dl0Mw0//uGztokqQ53YO7p9GeGvjB+H0dSDSclD3/cuAMyM\nrqp99Utus/vfPLyK97Yjzprp6fyirL10TRr84M8fd+36P3zj3sz7IpPv/Pfjr+0vFy6tjFjo9UWZ\nrNwrN4eV0ucua5Tn0oESwBiTQIZP3n6h+7OPFtcfjXzcIT/bap9mvLGVN+Wdp5vdK4drr974b7mt\nMT7/5JLvlwcF9VBPjyHlwlpY2XzidDOzknNYjwLj6mGkhCqm6B/Wr+1t/8F7wxufhcmmbI/bpSYf\nbly7tDF6+7MLVz6uD26V+Y1ZKRg/2j4Y7d0elVK1o1BUeJE1zCzEImZqxYpc3JxGYS6ucEpY0fl6\n82+W9/7Hh5cO8mwzdKN0cFeyZHn9cpr88N5eHV/uZw8fHGwuo+ey+dLR/NvnFx+59BueEdIVrc0d\n5q4oWhTbu5OIqHnQbjAAJM1M9JfO3rLDapHSeFZdPmtKVMDeXN/bvPPWpTSOLFuxlqFin/Noq9Cn\nZ4dZS6UObcaB7kYoYGZFru5FN8S+5FzM3Vdz3sLoWG6+55f2lt2z6Xa7OU6jZZHJrUtrh7//QnBU\nVaXitqwzhzytm36MgRFt0sKyLw1tZd+alcIbF2N2SBxU1QCsTFYbb6TZu/l8I47C+6n7+ODz+38X\nd5nfbK7w309qumkI0X1gMil91CBMwYJ0ZZyTHkizwlt392JXL4iTIcZetbj+tVfZmS42h5g63svR\n20fBhhvvaX312vQ/Pd0W0TIMManmYq6pD/Qg6NRkOWx5hc2qHoDVvmjZ2A9GSZFSDJodZmqOsp4n\nm0PTcNmnUZ1KHpW1gK+Ul3nzs2+NVAHzoR9US26z6mI1r+dDRFc2abtNWk0mN9MSrtTurBpxEYFp\nWTE903AQn85jN9u/FLaK9WF/49YXPjrxm5ftt3Ws5qamXTsUy6VY7svI3MupV8j9BP1+qNQcQsLN\ndjeBMKqEjKLQXIwUF+IsxYHzfmqdLwvH0r4of3i//xr3/svjagoz10KziL5vp46+1AO0PWIwnU/H\nw34IK7YAQsNBUtbRXCgxIIRIOMiwvbnQszBZhlMs0uVY9/bO4s78hd945fjtIDUFBjM3Vc8hS2kX\nCyvl81agPK+m9Vak+8pNdlufOFMwBynFJEw3K8DBcH5oGaJnBz7YxY1uI9298nloXhvv/oejxLpx\nw8pYRxjz84fHz54eQ/P9EIKHZ5hcraK5O9yKuW9GxIqMgYju4Np6glLgEFRnk/eufpImJV18e/Pm\nN+flZ1548uQHT3YW3nSBZhqc4jJhNfKXmlPPj+eMI4tHGO8h+mrw0SxMIUkcdGdUh3otDoEQLhvJ\nbj57utPw2cEtf7LV/e3LH8jZW2Ote4+BJIskA9jUV2VjyFruiddgPOH0gothZbO7pwohAjQhpZRs\nNCcdhNr2mV98OpvY2cmhXpl+uvvqHu6v/25US6W3Qrq5dkOfi4zWpxGDHZ4giJt1w2QUVvGNO1BS\nkooQJ4HYe4IbVpGBXzgZD+0MMmUbPnj5k6u/sjf6X1e/W3aO2HgfBgvuMLQxINq9zT4fV+8jMCCF\nTseBgDudbu5BquCEiBhiF2hBJES66w4wzQ+YlmuT+fn2R9M3dl76r/X8ETon2MfOUChV+zFGacij\n9aPR2uFhyBPz3CyXG2klpUAjLKVqlWuBiO6G+rkkmsqnrx01dnxhujbH7lE8aGbv9y98o+4KWeoe\nbe+BUo2vt5amu83avVK++dwVrTmsizgQFE53r+tgJEl3jwYHi7sANe5c70/X8tHENssz3cV+Xr/7\nxW+UlDTCUyFaqyhDrG3w7lQeI75/FCEamBCtdkDcxRVummpxhJXvLaqZKQ8GsJ/p9snTyaxqZ7vl\nVldtPpylV/7wSEiKqPS2OGv7Yrkg1qGcnDEevueeaktCRpUVewugmxVEShDkghCjq21Mq+JGMMoi\nh3ZedViOm6tPriy4+cfvBY0O62NcSjmlJINJasIk5dLd7hliseRCKVUWN3cC7mpB4FATMXhUQXye\nEsJ4NgwTicPoyafcGOzgo+99ly7RgQxK9uX5gaYoMU5iQpr1jxgTiLoUrW1VCwfdzDGVohBR9ZUV\n1fVR6CRR/HG3E8UWG90DuTisjX80CBgYNGgQZ5h/0mFtSg6JinyoqYYMIZTARNeOTndTc5UpDUIt\nunz0bjTqIi/pRmh/dX+4fxJtUHnx0+Vivb5664+fhroyybTciOfdzk9N1WKLUnIfCDoE0cJQ94cY\nVSGKA6WMpmpQL7MP/uynOY5sWc7rAhjzC/hztjtDHreznc8/fFMe8s5v/v7DldUGL8G69ZsXmhiZ\nxBZhfOmNr/zFu52WBgayXvbaUUJT0a3sVYPSyuz27x1Od+OLz8RmYaFuYfvsT6/9qC95uPn4J1++\nsHfw9nhjsnhz5oWRFrwkJ8ZBjVEHQzUNzauvPf7Wd/vkRel13kDfd+UkBpEexYRY3P6D053tKK+9\n+cJOPmvdGbsnr2CG8wuzyd7RvZTfG9fVXvfuGzY44JG0EQfNRnMUDXWSnO3yP/4X13KXc9YQJuPJ\nxtak6s+O/eTYIDZ75/e6g61KRC6/ernJIFGGW5MPWJqtn+kX23gwPJg16xvv7Z9d7EMIZEJskjto\nEsUQk5eSi9oX/uXrx+1gfZIU63q0vr65NlmvdkISe+cbabder1zacOHqhoNRiKvDrF6z248+Xdgn\nG58EHHeyV23vDwSDSKimG22UXCVzpkjt+5yL2cY/v7woLoEEQ6iayfoat2+E4Cc/ibvV+hoauX02\nPthxUoLIk/vLbCpPLobDnXf3hiXbxSXIhmT3bOjL2lYuOg2gC91yybmYmq3/k9jO4io9EYYQm2pn\nR6D3n1ySeqOwkbuP02hzcILkbAJMFirrOc7Gi63dxdEQ5YdXNsoqEdOmztbQnUKaWlY3czO7+hv9\nslFx54pG+ORGQ+8+CKGZdotaZDjxSgZzgvHMY8zFZNo2C+SFTM9SOMVOTrES7zk2VpQiMMOKqKkX\nB81+bbqYiJCAmaluv3y9oj0+HFU1lpEmed7DzQgKtp5K3Wrejb1NZs/aeT337mDz8XZyDSqy1kyj\nhEjY6uXFDOpwXf/FYbpip2ZZ5fWDStB+kL2ue68B8aVKTSXgXM5GS+fowmOItnX7VDYD49cflLhk\ngK5dGpc6mAhBNxhUV6cC3PTnMQGeRyQ63kmRevgkMea+EluKF/WmUtgwFK63CKPRR1ZTg9R867o8\nu/tELl/ZmjplfSQ5peAUFjM1MyvqbmauN/Yah6k7xTkJg7P7RELCoFGHHNl1UgU3JobQimEQPQl9\n0+a1tjmDjn94ht0tn7cbGzG0kxEDVla1azDVnMzEvf5iQ0AKAMZQEO34qdGHeQyd1ZL6VsPIzOHe\nNMhpV72lIAi26iiPN65vyfGDjwefjmGn0yoF+f/muxVdlUd/Pg4ZTCII1fLusejpcsg400lXAoWl\nLXFi5gb6JKTR8ZAYG60mZyW++8A+uHXNtM9DvTbivNuIpJDigLuWbJrNQPjr9SefHrUmkupp+Pgk\n5aO81DxMy0CcRw1DaoKaC2QuQealif30LGtMa1e/9/7W+r33XW2YbK6LP2w2CIJ0WRkC9JxKcAc3\nl59xNJ42TTNtTrHN88PT2fR0HYM058eRuQ+jaTEHuajJISVVbXKVy4CffX/r5MfI1DDamsrZJ19e\n+SYAXElQNRR1DQLuNjIsHyNMb1za7zbKxx8/XluMxi3q0vayfREyHtFhDPMyWBWHQvdaHffvb7+G\nqh3ENK5vQG/vXgiEm4is4lq3XLwUNXV/Ja5vraXlw9vf276yWZ/d+RTFtvrBYp5bvInG0ubCKWIh\n5BS6yioVM4nxXvXK1INaqbY2Jrjtr1R4Hh2YuCghtCGYqSgOxmiE9JNh09e6H3zf6nJxOJdJyqkW\n3Zo6xoQLWZaiQwjW18MFGdfV+vFn3QMPLtPNtdn3+9enz8WaQ1Zqqrjr0Gd1eLrWeYjTzYOd5Zb9\nn//Z17YTTrWedpareLo960OgEc68LA0hEevtxixXEm3rrpbA0MyP/GB/ugrYAIqYKE3MSvBStAT3\nl36aYgrc5nD0p+/mTVyYHM0m69QZLJ5+jn40CgoRyUdBK0kopR9d/+kLH758PBw8jaD0zeXtSYo0\nBTU4QVAIh7karAQVv+a91FrX5Xb/sKm78ejkzHYm2g+5j+Gsks3dJyTQPwvihI6fpb7diqPp07WH\nEzWm7YOtKiXh6riB0whBMBe4uWaREIKtb8zAOoHzfjOcjMZHC9najOenfTEp7XJ77WrtbmaTPohI\nZfUkVecv3dteVJ/bFcXW/kZKFDcJukq8HCRDogQHTM2sQG7kkoMEidOdkCZ9CaPLo+XJaQl7Yrq0\nAiGI/txNQtXkF2e1HY5Hm/1MH9ubO9tVCBJAcSdWh8ugcEEg6aZqpup2qxmGITTV5ELs8snQVDsb\n/uzzBXYP/h8CwMEc4WwTQAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "out = pool(V(input))\n", - "to_pil(out.data.squeeze(0))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "除了卷积层和池化层,深度学习中还将常用到以下几个层:\n", - "- Linear:全连接层。\n", - "- BatchNorm:批规范化层,分为1D、2D和3D。除了标准的BatchNorm之外,还有在风格迁移中常用到的InstanceNorm层。\n", - "- Dropout:dropout层,用来防止过拟合,同样分为1D、2D和3D。\n", - "下面通过例子来说明它们的使用。" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 0.0529 0.4152 0.6688 0.4281\n", - " 0.4504 -0.3291 0.4206 1.4391\n", - "[torch.FloatTensor of size 2x4]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 输入 batch_size=2,维度3\n", - "input = V(t.randn(2, 3))\n", - "linear = nn.Linear(3, 4)\n", - "h = linear(input)\n", - "h" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(Variable containing:\n", - " 1.00000e-07 *\n", - " 0.0000\n", - " 0.0000\n", - " 0.0000\n", - " 2.3842\n", - " [torch.FloatTensor of size 4], Variable containing:\n", - " 15.9960\n", - " 15.9988\n", - " 15.9896\n", - " 15.9994\n", - " [torch.FloatTensor of size 4])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 4 channel,初始化标准差为4,均值为0\n", - "bn = nn.BatchNorm1d(4)\n", - "bn.weight.data = t.ones(4) * 4\n", - "bn.bias.data = t.zeros(4)\n", - "\n", - "bn_out = bn(h)\n", - "# 注意输出的均值和方差\n", - "# 方差是标准差的平方,计算无偏方差分母会减1\n", - "# 使用unbiased=False 分母不减1\n", - "bn_out.mean(0), bn_out.var(0, unbiased=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - "-7.9990 0.0000 7.9974 -7.9998\n", - " 0.0000 -0.0000 -0.0000 7.9998\n", - "[torch.FloatTensor of size 2x4]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 每个元素以0.5的概率舍弃\n", - "dropout = nn.Dropout(0.5)\n", - "o = dropout(bn_out)\n", - "o # 有一半左右的数变为0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "以上很多例子中都对module的属性直接操作,其大多数是可学习参数,一般会随着学习的进行而不断改变。实际使用中除非需要使用特殊的初始化,应尽量不要直接修改这些参数。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 4.1.2 激活函数\n", - "PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档[^3],这些激活函数可作为独立的layer使用。这里将介绍最常用的激活函数ReLU,其数学表达式为:\n", - "$$ReLU(x)=max(0,x)$$\n", - "[^3]: http://pytorch.org/docs/nn.html#non-linear-activations" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Variable containing:\n", - "-0.6869 0.7347 -0.2196\n", - " 0.9445 -0.9042 -1.2652\n", - "[torch.FloatTensor of size 2x3]\n", - "\n", - "Variable containing:\n", - " 0.0000 0.7347 0.0000\n", - " 0.9445 0.0000 0.0000\n", - "[torch.FloatTensor of size 2x3]\n", - "\n" - ] - } - ], - "source": [ - "relu = nn.ReLU(inplace=True)\n", - "input = V(t.randn(2, 3))\n", - "print(input)\n", - "output = relu(input)\n", - "print(output) # 小于0的都被截断为0\n", - "# 等价于input.clamp(min=0)" - ] - }, - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "ReLU函数有个inplace参数,如果设为True,它会把输出直接覆盖到输入中,这样可以节省内存/显存。之所以可以覆盖是因为在计算ReLU的反向传播时,只需根据输出就能够推算出反向传播的梯度。但是只有少数的autograd操作支持inplace操作(如variable.sigmoid_()),除非你明确地知道自己在做什么,否则一般不要使用inplace操作。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "在以上的例子中,基本上都是将每一层的输出直接作为下一层的输入,这种网络称为前馈传播网络(feedforward neural network)。对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "net1: Sequential(\n", - " (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))\n", - " (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)\n", - " (activation_layer): ReLU()\n", - ")\n", - "net2: Sequential(\n", - " (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))\n", - " (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)\n", - " (2): ReLU()\n", - ")\n", - "net3: Sequential(\n", - " (conv1): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))\n", - " (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)\n", - " (relu1): ReLU()\n", - ")\n" - ] - } - ], - "source": [ - "# Sequential的三种写法\n", - "net1 = nn.Sequential()\n", - "net1.add_module('conv', nn.Conv2d(3, 3, 3))\n", - "net1.add_module('batchnorm', nn.BatchNorm2d(3))\n", - "net1.add_module('activation_layer', nn.ReLU())\n", - "\n", - "net2 = nn.Sequential(\n", - " nn.Conv2d(3, 3, 3),\n", - " nn.BatchNorm2d(3),\n", - " nn.ReLU()\n", - " )\n", - "\n", - "from collections import OrderedDict\n", - "net3= nn.Sequential(OrderedDict([\n", - " ('conv1', nn.Conv2d(3, 3, 3)),\n", - " ('bn1', nn.BatchNorm2d(3)),\n", - " ('relu1', nn.ReLU())\n", - " ]))\n", - "print('net1:', net1)\n", - "print('net2:', net2)\n", - "print('net3:', net3)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),\n", - " Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),\n", - " Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)))" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 可根据名字或序号取出子module\n", - "net1.conv, net2[0], net3.conv1" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "input = V(t.rand(1, 3, 4, 4))\n", - "output = net1(input)\n", - "output = net2(input)\n", - "output = net3(input)\n", - "output = net3.relu1(net1.batchnorm(net1.conv(input)))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])\n", - "input = V(t.randn(1, 3))\n", - "for model in modellist:\n", - " input = model(input)\n", - "# 下面会报错,因为modellist没有实现forward方法\n", - "# output = modelist(input)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "看到这里,读者可能会问,为何不直接使用Python中自带的list,而非要多此一举呢?这是因为`ModuleList`是`Module`的子类,当在`Module`中使用它的时候,就能自动识别为子module。\n", - "\n", - "下面举例说明。" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "MyModule(\n", - " (module_list): ModuleList(\n", - " (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))\n", - " (1): ReLU()\n", - " )\n", - ")" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "class MyModule(nn.Module):\n", - " def __init__(self):\n", - " super(MyModule, self).__init__()\n", - " self.list = [nn.Linear(3, 4), nn.ReLU()]\n", - " self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])\n", - " def forward(self):\n", - " pass\n", - "model = MyModule()\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "module_list.0.weight torch.Size([3, 3, 3, 3])\n", - "module_list.0.bias torch.Size([3])\n" - ] - } - ], - "source": [ - "for name, param in model.named_parameters():\n", - " print(name, param.size())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "可见,list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。如果在构造函数`__init__`中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 4.1.3 循环神经网络层(RNN)\n", - "近些年随着深度学习和自然语言处理的结合加深,RNN的使用也越来越多,关于RNN的基础知识,推荐阅读colah的文章[^4]入门。PyTorch中实现了如今最常用的三种RNN:RNN(vanilla RNN)、LSTM和GRU。此外还有对应的三种RNNCell。\n", - "\n", - "RNN和RNNCell层的区别在于前者一次能够处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性。实际上RNN层的一种后端实现方式就是调用RNNCell来实现的。\n", - "[^4]: http://colah.github.io/posts/2015-08-Understanding-LSTMs/" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - "(0 ,.,.) = \n", - " 0.0545 -0.0061 0.5615\n", - " -0.1251 0.4490 0.2640\n", - " 0.1405 -0.1624 0.0303\n", - "\n", - "(1 ,.,.) = \n", - " 0.0168 0.1562 0.5002\n", - " 0.0824 0.1454 0.4007\n", - " 0.0180 -0.0267 0.0094\n", - "[torch.FloatTensor of size 2x3x3]" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.manual_seed(1000)\n", - "# 输入:batch_size=3,序列长度都为2,序列中每个元素占4维\n", - "input = V(t.randn(2, 3, 4))\n", - "# lstm输入向量4维,隐藏元3,1层\n", - "lstm = nn.LSTM(4, 3, 1)\n", - "# 初始状态:1层,batch_size=3,3个隐藏元\n", - "h0 = V(t.randn(1, 3, 3))\n", - "c0 = V(t.randn(1, 3, 3))\n", - "out, hn = lstm(input, (h0, c0))\n", - "out" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - "(0 ,.,.) = \n", - " 0.0545 -0.0061 0.5615\n", - " -0.1251 0.4490 0.2640\n", - " 0.1405 -0.1624 0.0303\n", - "\n", - "(1 ,.,.) = \n", - " 0.0168 0.1562 0.5002\n", - " 0.0824 0.1454 0.4007\n", - " 0.0180 -0.0267 0.0094\n", - "[torch.FloatTensor of size 2x3x3]" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.manual_seed(1000)\n", - "input = V(t.randn(2, 3, 4))\n", - "# 一个LSTMCell对应的层数只能是一层\n", - "lstm = nn.LSTMCell(4, 3)\n", - "hx = V(t.randn(3, 3))\n", - "cx = V(t.randn(3, 3))\n", - "out = []\n", - "for i_ in input:\n", - " hx, cx=lstm(i_, (hx, cx))\n", - " out.append(hx)\n", - "t.stack(out)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "词向量在自然语言中应用十分普及,PyTorch同样提供了Embedding层。" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# 有4个词,每个词用5维的向量表示\n", - "embedding = nn.Embedding(4, 5)\n", - "# 可以用预训练好的词向量初始化embedding\n", - "embedding.weight.data = t.arange(0,20).view(4,5)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 15 16 17 18 19\n", - " 10 11 12 13 14\n", - " 5 6 7 8 9\n", - "[torch.FloatTensor of size 3x5]" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "input = V(t.arange(3, 0, -1)).long()\n", - "output = embedding(input)\n", - "output" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 4.1.4 损失函数\n", - "在深度学习中要用到各种各样的损失函数(loss function),这些损失函数可看作是一种特殊的layer,PyTorch也将这些损失函数实现为`nn.Module`的子类。然而在实际使用中通常将这些loss function专门提取出来,和主模型互相独立。详细的loss使用请参照文档[^5],这里以分类中最常用的交叉熵损失CrossEntropyloss为例说明。\n", - "[^5]: http://pytorch.org/docs/nn.html#loss-functions" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1.5544\n", - "[torch.FloatTensor of size 1]" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# batch_size=3,计算对应每个类别的分数(只有两个类别)\n", - "score = V(t.randn(3, 2))\n", - "# 三个样本分别属于1,0,1类,label必须是LongTensor\n", - "label = V(t.Tensor([1, 0, 1])).long()\n", - "\n", - "# loss与普通的layer无差异\n", - "criterion = nn.CrossEntropyLoss()\n", - "loss = criterion(score, label)\n", - "loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.2 优化器\n", - "\n", - "PyTorch将深度学习中常用的优化方法全部封装在`torch.optim`中,其设计十分灵活,能够很方便的扩展成自定义的优化方法。\n", - "\n", - "所有的优化方法都是继承基类`optim.Optimizer`,并实现了自己的优化步骤。下面就以最基本的优化方法——随机梯度下降法(SGD)举例说明。这里需重点掌握:\n", - "\n", - "- 优化方法的基本使用方法\n", - "- 如何对模型的不同部分设置不同的学习率\n", - "- 如何调整学习率" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "# 首先定义一个LeNet网络\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super(Net, self).__init__()\n", - " self.features = nn.Sequential(\n", - " nn.Conv2d(3, 6, 5),\n", - " nn.ReLU(),\n", - " nn.MaxPool2d(2,2),\n", - " nn.Conv2d(6, 16, 5),\n", - " nn.ReLU(),\n", - " nn.MaxPool2d(2,2)\n", - " )\n", - " self.classifier = nn.Sequential(\n", - " nn.Linear(16 * 5 * 5, 120),\n", - " nn.ReLU(),\n", - " nn.Linear(120, 84),\n", - " nn.ReLU(),\n", - " nn.Linear(84, 10)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x = self.features(x)\n", - " x = x.view(-1, 16 * 5 * 5)\n", - " x = self.classifier(x)\n", - " return x\n", - "\n", - "net = Net()" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "from torch import optim\n", - "optimizer = optim.SGD(params=net.parameters(), lr=1)\n", - "optimizer.zero_grad() # 梯度清零,等价于net.zero_grad()\n", - "\n", - "input = V(t.randn(1, 3, 32, 32))\n", - "output = net(input)\n", - "output.backward(output) # fake backward\n", - "\n", - "optimizer.step() # 执行优化" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "# 为不同子网络设置不同的学习率,在finetune中经常用到\n", - "# 如果对某个参数不指定学习率,就使用最外层的默认学习率\n", - "optimizer =optim.SGD([\n", - " {'params': net.features.parameters()}, # 学习率为1e-5\n", - " {'params': net.classifier.parameters(), 'lr': 1e-2}\n", - " ], lr=1e-5)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "# 只为两个全连接层设置较大的学习率,其余层的学习率较小\n", - "special_layers = nn.ModuleList([net.classifier[0], net.classifier[2]])\n", - "special_layers_params = list(map(id, special_layers.parameters()))\n", - "base_params = filter(lambda p: id(p) not in special_layers_params,\n", - " net.parameters())\n", - "\n", - "optimizer = t.optim.SGD([\n", - " {'params': base_params},\n", - " {'params': special_layers.parameters(), 'lr': 0.01}\n", - " ], lr=0.001 )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "对于如何调整学习率,主要有两种做法。一种是修改optimizer.param_groups中对应的学习率,另一种是更简单也是较为推荐的做法——新建优化器,由于optimizer十分轻量级,构建开销很小,故而可以构建新的optimizer。但是后者对于使用动量的优化器(如Adam),会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况。" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ - "# 调整学习率,新建一个optimizer\n", - "old_lr = 0.1\n", - "optimizer =optim.SGD([\n", - " {'params': net.features.parameters()},\n", - " {'params': net.classifier.parameters(), 'lr': old_lr*0.1}\n", - " ], lr=1e-5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.3 nn.functional\n", - "\n", - "nn中还有一个很常用的模块:`nn.functional`,nn中的大多数layer,在`functional`中都有一个与之相对应的函数。`nn.functional`中的函数和`nn.Module`的主要区别在于,用nn.Module实现的layers是一个特殊的类,都是由`class layer(nn.Module)`定义,会自动提取可学习的参数。而`nn.functional`中的函数更像是纯函数,由`def function(input)`定义。下面举例说明functional的使用,并指出二者的不同之处。" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1 1 1 1\n", - " 1 1 1 1\n", - "[torch.ByteTensor of size 2x4]" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "input = V(t.randn(2, 3))\n", - "model = nn.Linear(3, 4)\n", - "output1 = model(input)\n", - "output2 = nn.functional.linear(input, model.weight, model.bias)\n", - "output1 == output2" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 1 1 1\n", - " 1 1 1\n", - "[torch.ByteTensor of size 2x3]" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = nn.functional.relu(input)\n", - "b2 = nn.ReLU()(input)\n", - "b == b2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "此时读者可能会问,应该什么时候使用nn.Module,什么时候使用nn.functional呢?答案很简单,如果模型有可学习的参数,最好用nn.Module,否则既可以使用nn.functional也可以使用nn.Module,二者在性能上没有太大差异,具体的使用取决于个人的喜好。如激活函数(ReLU、sigmoid、tanh),池化(MaxPool)等层由于没有可学习参数,则可以使用对应的functional函数代替,而对于卷积、全连接等具有可学习参数的网络建议使用nn.Module。下面举例说明,如何在模型中搭配使用nn.Module和nn.functional。另外虽然dropout操作也没有可学习操作,但建议还是使用`nn.Dropout`而不是`nn.functional.dropout`,因为dropout在训练和测试两个阶段的行为有所差别,使用`nn.Module`对象能够通过`model.eval`操作加以区分。" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.nn import functional as F\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super(Net, self).__init__()\n", - " self.conv1 = nn.Conv2d(3, 6, 5)\n", - " self.conv2 = nn.Conv2d(6, 16, 5)\n", - " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", - " self.fc2 = nn.Linear(120, 84)\n", - " self.fc3 = nn.Linear(84, 10)\n", - "\n", - " def forward(self, x):\n", - " x = F.pool(F.relu(self.conv1(x)), 2)\n", - " x = F.pool(F.relu(self.conv2(x)), 2)\n", - " x = x.view(-1, 16 * 5 * 5)\n", - " x = F.relu(self.fc1(x))\n", - " x = F.relu(self.fc2(x))\n", - " x = self.fc3(x)\n", - " return x" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "对于不具备可学习参数的层(激活层、池化层等),将它们用函数代替,这样则可以不用放置在构造函数`__init__`中。对于有可学习参数的模块,也可以用functional来代替,只不过实现起来较为繁琐,需要手动定义参数parameter,如前面实现自定义的全连接层,就可将weight和bias两个参数单独拿出来,在构造函数中初始化为parameter。" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "class MyLinear(nn.Module):\n", - " def __init__(self):\n", - " super(MyLinear, self).__init__()\n", - " self.weight = nn.Parameter(t.randn(3, 4))\n", - " self.bias = nn.Parameter(t.zeros(3))\n", - " def forward(self,input):\n", - " return F.linear(input, self.weight, self.bias)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "关于nn.functional的设计初衷,以及它和nn.Module更多的比较说明,可参看论坛的讨论和作者说明[^6]。\n", - "[^6]: https://discuss.pytorch.org/search?q=nn.functional" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.4 初始化策略\n", - "在深度学习中参数的初始化十分重要,良好的初始化能让模型更快收敛,并达到更高水平,而糟糕的初始化则可能使得模型迅速瘫痪。PyTorch中nn.Module的模块参数都采取了较为合理的初始化策略,因此一般不用我们考虑,当然我们也可以用自定义初始化去代替系统的默认初始化。而当我们在使用Parameter时,自定义初始化则尤为重要,因t.Tensor()返回的是内存中的随机数,很可能会有极大值,这在实际训练网络中会造成溢出或者梯度消失。PyTorch中`nn.init`模块就是专门为初始化而设计,如果某种初始化策略`nn.init`不提供,用户也可以自己直接初始化。" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Parameter containing:\n", - " 0.3535 0.1427 0.0330\n", - " 0.3321 -0.2416 -0.0888\n", - "-0.8140 0.2040 -0.5493\n", - "-0.3010 -0.4769 -0.0311\n", - "[torch.FloatTensor of size 4x3]" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 利用nn.init初始化\n", - "from torch.nn import init\n", - "linear = nn.Linear(3, 4)\n", - "\n", - "t.manual_seed(1)\n", - "# 等价于 linear.weight.data.normal_(0, std)\n", - "init.xavier_normal(linear.weight)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - " 0.3535 0.1427 0.0330\n", - " 0.3321 -0.2416 -0.0888\n", - "-0.8140 0.2040 -0.5493\n", - "-0.3010 -0.4769 -0.0311\n", - "[torch.FloatTensor of size 4x3]" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 直接初始化\n", - "import math\n", - "t.manual_seed(1)\n", - "\n", - "# xavier初始化的计算公式\n", - "std = math.sqrt(2)/math.sqrt(7.)\n", - "linear.weight.data.normal_(0,std)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "# 对模型的所有参数进行初始化\n", - "for name, params in net.named_parameters():\n", - " if name.find('linear') != -1:\n", - " # init linear\n", - " params[0] # weight\n", - " params[1] # bias\n", - " elif name.find('conv') != -1:\n", - " pass\n", - " elif name.find('norm') != -1:\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.5 nn.Module深入分析\n", - "\n", - "如果想要更深入地理解nn.Module,究其原理是很有必要的。首先来看看nn.Module基类的构造函数:\n", - "```python\n", - "def __init__(self):\n", - " self._parameters = OrderedDict()\n", - " self._modules = OrderedDict()\n", - " self._buffers = OrderedDict()\n", - " self._backward_hooks = OrderedDict()\n", - " self._forward_hooks = OrderedDict()\n", - " self.training = True\n", - "```\n", - "其中每个属性的解释如下:\n", - "\n", - "- `_parameters`:字典,保存用户直接设置的parameter,`self.param1 = nn.Parameter(t.randn(3, 3))`会被检测到,在字典中加入一个key为'param',value为对应parameter的item。而self.submodule = nn.Linear(3, 4)中的parameter则不会存于此。\n", - "- `_modules`:子module,通过`self.submodel = nn.Linear(3, 4)`指定的子module会保存于此。\n", - "- `_buffers`:缓存。如batchnorm使用momentum机制,每次前向传播需用到上一次前向传播的结果。\n", - "- `_backward_hooks`与`_forward_hooks`:钩子技术,用来提取中间变量,类似variable的hook。\n", - "- `training`:BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。\n", - "\n", - "上述几个属性中,`_parameters`、`_modules`和`_buffers`这三个字典中的键值,都可以通过`self.key`方式获得,效果等价于`self._parameters['key']`.\n", - "\n", - "下面举例说明。" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Net(\n", - " (submodel1): Linear(in_features=3, out_features=4)\n", - ")" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super(Net, self).__init__()\n", - " # 等价与self.register_parameter('param1' ,nn.Parameter(t.randn(3, 3)))\n", - " self.param1 = nn.Parameter(t.rand(3, 3))\n", - " self.submodel1 = nn.Linear(3, 4) \n", - " def forward(self, input):\n", - " x = self.param1.mm(input)\n", - " x = self.submodel1(x)\n", - " return x\n", - "net = Net()\n", - "net" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('submodel1', Linear(in_features=3, out_features=4))])" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net._modules" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('param1', Parameter containing:\n", - " 0.3398 0.5239 0.7981\n", - " 0.7718 0.0112 0.8100\n", - " 0.6397 0.9743 0.8300\n", - " [torch.FloatTensor of size 3x3])])" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net._parameters" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Parameter containing:\n", - " 0.3398 0.5239 0.7981\n", - " 0.7718 0.0112 0.8100\n", - " 0.6397 0.9743 0.8300\n", - "[torch.FloatTensor of size 3x3]" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "net.param1 # 等价于net._parameters['param1']" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "param1 torch.Size([3, 3])\n", - "submodel1.weight torch.Size([4, 3])\n", - "submodel1.bias torch.Size([4])\n" - ] - } - ], - "source": [ - "for name, param in net.named_parameters():\n", - " print(name, param.size())" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Net(\n", - " (submodel1): Linear(in_features=3, out_features=4)\n", - ")\n", - "submodel1 Linear(in_features=3, out_features=4)\n" - ] - } - ], - "source": [ - "for name, submodel in net.named_modules():\n", - " print(name, submodel)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('running_mean', \n", - " 1.00000e-02 *\n", - " 5.1362\n", - " 7.4864\n", - " [torch.FloatTensor of size 2]), ('running_var', \n", - " 0.9116\n", - " 0.9068\n", - " [torch.FloatTensor of size 2])])" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "bn = nn.BatchNorm1d(2)\n", - "input = V(t.rand(3, 2), requires_grad=True)\n", - "output = bn(input)\n", - "bn._buffers" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "nn.Module在实际使用中可能层层嵌套,一个module包含若干个子module,每一个子module又包含了更多的子module。为方便用户访问各个子module,nn.Module实现了很多方法,如函数`children`可以查看直接子module,函数`module`可以查看所有的子module(包括当前module)。与之相对应的还有函数`named_childen`和`named_modules`,其能够在返回module列表的同时返回它们的名字。" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 0 2 0 0\n", - " 8 0 12 14\n", - " 16 0 0 22\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "input = V(t.arange(0, 12).view(3, 4))\n", - "model = nn.Dropout()\n", - "# 在训练阶段,会有一半左右的数被随机置为0\n", - "model(input)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Variable containing:\n", - " 0 1 2 3\n", - " 4 5 6 7\n", - " 8 9 10 11\n", - "[torch.FloatTensor of size 3x4]" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.training = False\n", - "# 在测试阶段,dropout什么都不做\n", - "model(input)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "对于batchnorm、dropout、instancenorm等在训练和测试阶段行为差距巨大的层,如果在测试时不将其training值设为True,则可能会有很大影响,这在实际使用中要千万注意。虽然可通过直接设置`training`属性,来将子module设为train和eval模式,但这种方式较为繁琐,因如果一个模型具有多个dropout层,就需要为每个dropout层指定training属性。更为推荐的做法是调用`model.train()`函数,它会将当前module及其子module中的所有training属性都设为True,相应的,`model.eval()`函数会把training属性都设为False。" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True True\n" - ] - }, - { - "data": { - "text/plain": [ - "(False, False)" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(net.training, net.submodel1.training)\n", - "net.eval()\n", - "net.training, net.submodel1.training" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[('', Net(\n", - " (submodel1): Linear(in_features=3, out_features=4)\n", - " )), ('submodel1', Linear(in_features=3, out_features=4))]" - ] - }, - "execution_count": 48, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(net.named_modules())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`register_forward_hook`与`register_backward_hook`,这两个函数的功能类似于variable函数的`register_hook`,可在module前向传播或反向传播时注册钩子。每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:`hook(module, input, output) -> None`,而反向传播则具有如下形式:`hook(module, grad_input, grad_output) -> Tensor or None`。钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。\n", - "```python\n", - "model = VGG()\n", - "features = t.Tensor()\n", - "def hook(module, input, output):\n", - " '''把这层的输出拷贝到features中'''\n", - " features.copy_(output.data)\n", - " \n", - "handle = model.layer8.register_forward_hook(hook)\n", - "_ = model(input)\n", - "# 用完hook后删除\n", - "handle.remove()\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`nn.Module`对象在构造函数中的行为看起来有些怪异,如果想要真正掌握其原理,就需要看两个魔法方法`__getattr__`和`__setattr__`。在Python中有两个常用的buildin方法`getattr`和`setattr`,`getattr(obj, 'attr1')`等价于`obj.attr`,如果`getattr`函数无法找到所需属性,Python会转而调用`obj.__getattr__('attr1')`方法,即`getattr`函数无法找到的交给`__getattr__`函数处理,没有实现`__getattr__`或者`__getattr__`也无法处理的就会raise AttributeError。`setattr(obj, 'name', value)`等价于`obj.name=value`,如果obj对象实现了`__setattr__`方法,setattr会直接调用`obj.__setattr__('name', value)`,否则调用buildin方法。总结一下:\n", - "- result = obj.name会调用buildin函数`getattr(obj, 'name')`,如果该属性找不到,会调用`obj.__getattr__('name')`\n", - "- obj.name = value会调用buildin函数`setattr(obj, 'name', value)`,如果obj对象实现了`__setattr__`方法,`setattr`会直接调用`obj.__setattr__('name', value')`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "nn.Module实现了自定义的`__setattr__`函数,当执行`module.name=value`时,会在`__setattr__`中判断value是否为`Parameter`或`nn.Module`对象,如果是则将这些对象加到`_parameters`和`_modules`两个字典中,而如果是其它类型的对象,如`Variable`、`list`、`dict`等,则调用默认的操作,将这个值保存在`__dict__`中。" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('param', Parameter containing:\n", - " 1 1\n", - " 1 1\n", - " [torch.FloatTensor of size 2x2])])" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "module = nn.Module()\n", - "module.param = nn.Parameter(t.ones(2, 2))\n", - "module._parameters" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "_modules: OrderedDict()\n", - "__dict__['submodules']: [Linear(in_features=2, out_features=2), Linear(in_features=2, out_features=2)]\n" - ] - } - ], - "source": [ - "submodule1 = nn.Linear(2, 2)\n", - "submodule2 = nn.Linear(2, 2)\n", - "module_list = [submodule1, submodule2]\n", - "# 对于list对象,调用buildin函数,保存在__dict__中\n", - "module.submodules = module_list\n", - "print('_modules: ', module._modules)\n", - "print(\"__dict__['submodules']:\",module.__dict__.get('submodules'))" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ModuleList is instance of nn.Module: True\n", - "_modules: OrderedDict([('submodules', ModuleList(\n", - " (0): Linear(in_features=2, out_features=2)\n", - " (1): Linear(in_features=2, out_features=2)\n", - "))])\n", - "__dict__['submodules']: None\n" - ] - } - ], - "source": [ - "module_list = nn.ModuleList(module_list)\n", - "module.submodules = module_list\n", - "print('ModuleList is instance of nn.Module: ', isinstance(module_list, nn.Module))\n", - "print('_modules: ', module._modules)\n", - "print(\"__dict__['submodules']:\", module.__dict__.get('submodules'))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "因`_modules`和`_parameters`中的item未保存在`__dict__`中,所以默认的getattr方法无法获取它,因而`nn.Module`实现了自定义的`__getattr__`方法,如果默认的`getattr`无法处理,就调用自定义的`__getattr__`方法,尝试从`_modules`、`_parameters`和`_buffers`这三个字典中获取。" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "getattr(module, 'training') # 等价于module.training\n", - "# error\n", - "# module.__getattr__('training')" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "2" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "module.attr1 = 2\n", - "getattr(module, 'attr1')\n", - "# 报错\n", - "# module.__getattr__('attr1')" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Parameter containing:\n", - " 1 1\n", - " 1 1\n", - "[torch.FloatTensor of size 2x2]" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 即module.param, 会调用module.__getattr__('param')\n", - "getattr(module, 'param')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "在PyTorch中保存模型十分简单,所有的Module对象都具有state_dict()函数,返回当前Module所有的状态数据。将这些状态数据保存后,下次使用模型时即可利用`model.load_state_dict()`函数将状态加载进来。优化器(optimizer)也有类似的机制,不过一般并不需要保存优化器的运行状态。" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "# 保存模型\n", - "t.save(net.state_dict(), 'net.pth')\n", - "\n", - "# 加载已保存的模型\n", - "net2 = Net()\n", - "net2.load_state_dict(t.load('net.pth'))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "实际上还有另外一种保存方法,但因其严重依赖模型定义方式及文件路径结构等,很容易出问题,因而不建议使用。" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:158: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.\n", - " \"type \" + obj.__name__ + \". It won't be checked \"\n" - ] - }, - { - "data": { - "text/plain": [ - "Net(\n", - " (submodel1): Linear(in_features=3, out_features=4)\n", - ")" - ] - }, - "execution_count": 56, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.save(net, 'net_all.pth')\n", - "net2 = t.load('net_all.pth')\n", - "net2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "将Module放在GPU上运行也十分简单,只需两步:\n", - "- model = model.cuda():将模型的所有参数转存到GPU\n", - "- input.cuda():将输入数据也放置到GPU上\n", - "\n", - "至于如何在多个GPU上并行计算,PyTorch也提供了两个函数,可实现简单高效的并行GPU计算\n", - "- nn.parallel.data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None)\n", - "- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)\n", - "\n", - "可见二者的参数十分相似,通过`device_ids`参数可以指定在哪些GPU上进行优化,output_device指定输出到哪个GPU上。唯一的不同就在于前者直接利用多GPU并行计算得出结果,而后者则返回一个新的module,能够自动在多GPU上进行并行加速。\n", - "\n", - "```\n", - "# method 1\n", - "new_net = nn.DataParallel(net, device_ids=[0, 1])\n", - "output = new_net(input)\n", - "\n", - "# method 2\n", - "output = nn.parallel.data_parallel(new_net, input, device_ids=[0, 1])\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "DataParallel并行的方式,是将输入一个batch的数据均分成多份,分别送到对应的GPU进行计算,各个GPU得到的梯度累加。与Module相关的所有数据也都会以浅复制的方式复制多份,在此需要注意,在module中属性应该是只读的。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.6 nn和autograd的关系\n", - "nn.Module利用的也是autograd技术,其主要工作是实现前向传播。在forward函数中,nn.Module对输入的Variable进行的各种操作,本质上都是用到了autograd技术。这里需要对比autograd.Function和nn.Module之间的区别:\n", - "- autograd.Function利用了Tensor对autograd技术的扩展,为autograd实现了新的运算op,不仅要实现前向传播还要手动实现反向传播\n", - "- nn.Module利用了autograd技术,对nn的功能进行扩展,实现了深度学习中更多的层。只需实现前向传播功能,autograd即会自动实现反向传播\n", - "- nn.functional是一些autograd操作的集合,是经过封装的函数\n", - "\n", - "作为两大类扩充PyTorch接口的方法,我们在实际使用中应该如何选择呢?如果某一个操作,在autograd中尚未支持,那么只能实现Function接口对应的前向传播和反向传播。如果某些时候利用autograd接口比较复杂,则可以利用Function将多个操作聚合,实现优化,正如第三章所实现的`Sigmoid`一样,比直接利用autograd低级别的操作要快。而如果只是想在深度学习中增加某一层,使用nn.Module进行封装则更为简单高效。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.7 小试牛刀:搭建ResNet\n", - "Kaiming He的深度残差网络(ResNet)[^7]在深度学习的发展中起到了很重要的作用,ResNet不仅一举拿下了当年CV下多个比赛项目的冠军,更重要的是这一结构解决了训练极深网络时的梯度消失问题。\n", - "\n", - "首先来看看ResNet的网络结构,这里选取的是ResNet的一个变种:ResNet34。ResNet的网络结构如图4-2所示,可见除了最开始的卷积池化和最后的池化全连接之外,网络中有很多结构相似的单元,这些重复单元的共同点就是有个跨层直连的shortcut。ResNet中将一个跨层直连的单元称为Residual block,其结构如图4-3所示,左边部分是普通的卷积网络结构,右边是直连,但如果输入和输出的通道数不一致,或其步长不为1,那么就需要有一个专门的单元将二者转成一致,使其可以相加。\n", - "\n", - "另外我们可以发现Residual block的大小也是有规律的,在最开始的pool之后有连续的几个一模一样的Residual block单元,这些单元的通道数一样,在这里我们将这几个拥有多个Residual block单元的结构称之为layer,注意和之前讲的layer区分开来,这里的layer是几个层的集合。\n", - "\n", - "考虑到Residual block和layer出现了多次,我们可以把它们实现为一个子Module或函数。这里我们将Residual block实现为一个子moduke,而将layer实现为一个函数。下面是实现代码,规律总结如下:\n", - "\n", - "- 对于模型中的重复部分,实现为子module或用函数生成相应的module`make_layer`\n", - "- nn.Module和nn.Functional结合使用\n", - "- 尽量使用`nn.Seqential`\n", - "\n", - "![图4-2: ResNet34网络结构](imgs/resnet1.png)\n", - "![图4-3: Residual block 结构图](imgs/residual.png)\n", - " [^7]: He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016: 770-778." - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn\n", - "import torch as t\n", - "from torch.nn import functional as F" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [], - "source": [ - "class ResidualBlock(nn.Module):\n", - " '''\n", - " 实现子module: Residual Block\n", - " '''\n", - " def __init__(self, inchannel, outchannel, stride=1, shortcut=None):\n", - " super(ResidualBlock, self).__init__()\n", - " self.left = nn.Sequential(\n", - " nn.Conv2d(inchannel,outchannel,3,stride, 1,bias=False),\n", - " nn.BatchNorm2d(outchannel),\n", - " nn.ReLU(inplace=True),\n", - " nn.Conv2d(outchannel,outchannel,3,1,1,bias=False),\n", - " nn.BatchNorm2d(outchannel) )\n", - " self.right = shortcut\n", - "\n", - " def forward(self, x):\n", - " out = self.left(x)\n", - " residual = x if self.right is None else self.right(x)\n", - " out += residual\n", - " return F.relu(out)\n", - "\n", - "class ResNet(nn.Module):\n", - " '''\n", - " 实现主module:ResNet34\n", - " ResNet34 包含多个layer,每个layer又包含多个residual block\n", - " 用子module来实现residual block,用_make_layer函数来实现layer\n", - " '''\n", - " def __init__(self, num_classes=1000):\n", - " super(ResNet, self).__init__()\n", - " # 前几层图像转换\n", - " self.pre = nn.Sequential(\n", - " nn.Conv2d(3, 64, 7, 2, 3, bias=False),\n", - " nn.BatchNorm2d(64),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool2d(3, 2, 1))\n", - " \n", - " # 重复的layer,分别有3,4,6,3个residual block\n", - " self.layer1 = self._make_layer( 64, 64, 3)\n", - " self.layer2 = self._make_layer( 64, 128, 4, stride=2)\n", - " self.layer3 = self._make_layer( 128, 256, 6, stride=2)\n", - " self.layer4 = self._make_layer( 256, 512, 3, stride=2)\n", - "\n", - " #分类用的全连接\n", - " self.fc = nn.Linear(512, num_classes)\n", - " \n", - " def _make_layer(self, inchannel, outchannel, block_num, stride=1):\n", - " '''\n", - " 构建layer,包含多个residual block\n", - " '''\n", - " shortcut = nn.Sequential(\n", - " nn.Conv2d(inchannel,outchannel,1,stride, bias=False),\n", - " nn.BatchNorm2d(outchannel))\n", - " \n", - " layers = []\n", - " layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))\n", - " \n", - " for i in range(1, block_num):\n", - " layers.append(ResidualBlock(outchannel, outchannel))\n", - " return nn.Sequential(*layers)\n", - " \n", - " def forward(self, x):\n", - " x = self.pre(x)\n", - " \n", - " x = self.layer1(x)\n", - " x = self.layer2(x)\n", - " x = self.layer3(x)\n", - " x = self.layer4(x)\n", - "\n", - " x = F.avg_pool2d(x, 7)\n", - " x = x.view(x.size(0), -1)\n", - " return self.fc(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [], - "source": [ - "model = ResNet()\n", - "input = t.autograd.Variable(t.randn(1, 3, 224, 224))\n", - "o = model(input)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "感兴趣的读者可以尝试实现Google的Inception网络结构或ResNet的其它变体,看看如何能够简洁明了地实现它,实现代码尽量控制在80行以内(本例去掉空行和注释总共不超过50行)。另外,与PyTorch配套的图像工具包`torchvision`已经实现了深度学习中大多数经典的模型,其中就包括ResNet34,读者可以通过下面两行代码使用:\n", - "```python\n", - "from torchvision import models\n", - "model = models.resnet34()\n", - "```\n", - "本例中ResNet34的实现就是参考了torchvision中的实现并做了简化,感兴趣的读者可以阅读相应的源码,比较这里的实现和torchvision中实现的不同。" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/0_basic/2-autograd.ipynb b/6_pytorch/2-autograd.ipynb similarity index 78% rename from 6_pytorch/0_basic/2-autograd.ipynb rename to 6_pytorch/2-autograd.ipynb index 0d8676a..2a65563 100644 --- a/6_pytorch/0_basic/2-autograd.ipynb +++ b/6_pytorch/2-autograd.ipynb @@ -5,7 +5,17 @@ "metadata": {}, "source": [ "# 自动求导\n", - "这次课程我们会了解 PyTorch 中的自动求导机制,自动求导是 PyTorch 中非常重要的特性,能够让我们避免手动去计算非常复杂的导数,这能够极大地减少了我们构建模型的时间,这也是其前身 Torch 这个框架所不具备的特性,下面我们通过例子看看 PyTorch 自动求导的独特魅力以及探究自动求导的更多用法。" + "\n", + "自动求导是 PyTorch 中非常重要的特性,能够让我们避免手动去计算非常复杂的导数,这能够极大地减少构建模型的时间。 PyTorch 的 Autograd 模块实现了深度学习的算法中的反向传播求导数,在张量(Tensor类)上的所有操作, Autograd 都能为他们自动提供微分,简化了手动计算导数的复杂过程。\n", + "\n", + "在PyTorch 0.4以前的版本中, PyTorch 使用 `Variabe` 类来自动计算所有的梯度 `Variable` 类主要包含三个属性 \n", + "* Variable 所包含的 Tensor;\n", + "* grad:保存 data 对应的梯度,grad 也是个 Variable,而不是 Tensor,它和 data 的形状一样;\n", + "* grad_fn:指向一个 Function 对象,这个 Function 用来反向传播计算输入的梯度;\n", + "\n", + "从 PyTorch 0.4版本起, `Variable` 正式合并入 `Tensor` 类,通过 `Variable` 嵌套实现的自动微分功能已经整合进入了 `Tensor` 类中。虽然为了的兼容性还是可以使用 `Variable`(tensor)这种方式进行嵌套,但是这个操作其实什么都没做。\n", + "\n", + "以后的代码建议直接使用 `Tensor` 类进行操作,因为官方文档中已经将 `Variable` 设置成过期模块。" ] }, { @@ -14,8 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "import torch\n", - "from torch.autograd import Variable" + "import torch" ] }, { @@ -28,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -40,7 +49,7 @@ } ], "source": [ - "x = Variable(torch.Tensor([2]), requires_grad=True)\n", + "x = torch.tensor([2.0], requires_grad=True)\n", "y = x + 2\n", "z = y ** 2 + 3\n", "print(z)" @@ -67,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -93,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -106,16 +115,16 @@ } ], "source": [ - "# 定义Variable\n", - "x = Variable(torch.FloatTensor([1,2]), requires_grad=False)\n", - "b = Variable(torch.FloatTensor([5,6]), requires_grad=False)\n", - "w = Variable(torch.FloatTensor([[1,2],[3,4]]), requires_grad=True)\n", + "# 定义变量\n", + "x = torch.tensor([1,2], dtype=torch.float, requires_grad=False)\n", + "b = torch.tensor([5,6], dtype=torch.float, requires_grad=False)\n", + "w = torch.tensor([[1,2],[3,4]], dtype=torch.float, requires_grad=True)\n", "print(w)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -178,7 +187,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "上面数学公式就更加复杂,矩阵乘法之后对两个矩阵对应元素相乘,然后所有元素求平均,有兴趣的同学可以手动去计算一下梯度,使用 PyTorch 的自动求导,我们能够非常容易得到 x, y 和 w 的导数,因为深度学习中充满大量的矩阵运算,所以我们没有办法手动去求这些导数,有了自动求导能够非常方便地解决网络更新的问题。" + "上面数学公式的具体含义是:矩阵乘法之后对两个矩阵对应元素相乘,然后所有元素求平均。使用 PyTorch 的自动求导,能够非常容易得到 对 `w` 的导数,因为深度学习中充满大量的矩阵运算,所以手动去求这些导数比较费时间和精力,有了自动求导能够非常方便地解决网络更新的问题。" ] }, { @@ -192,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -205,22 +214,22 @@ } ], "source": [ - "m = Variable(torch.FloatTensor([[2, 3]]), requires_grad=True) # 构建一个 1 x 2 的矩阵\n", - "n = Variable(torch.zeros(1, 2)) # 构建一个相同大小的 0 矩阵\n", + "m = torch.tensor([[2, 3]], dtype=torch.float, requires_grad=True) # 构建一个 1 x 2 的矩阵\n", + "n = torch.zeros(1, 2) # 构建一个相同大小的 0 矩阵\n", "print(m)\n", "print(n)" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(2., grad_fn=)\n", + "tensor(2., grad_fn=)\n", "tensor([[ 4., 27.]], grad_fn=)\n" ] } @@ -247,7 +256,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "下面我们直接对 n 进行反向传播,也就是求 n 对 m 的导数。\n", + "下面我们直接对 `n` 进行反向传播,也就是求 `n` 对 `m` 的导数。\n", "\n", "这时我们需要明确这个导数的定义,即如何定义\n", "\n", @@ -326,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -338,14 +347,14 @@ } ], "source": [ - "x = Variable(torch.FloatTensor([3]), requires_grad=True)\n", + "x = torch.tensor([3], dtype=torch.float, requires_grad=True)\n", "y = x * 2 + x ** 2 + 3\n", "print(y)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -354,7 +363,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -371,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -380,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -446,28 +455,17 @@ "\\frac{\\partial k_1}{\\partial x_0} & \\frac{\\partial k_1}{\\partial x_1}\n", "\\end{matrix}\n", "\\right]\n", - "$$\n", - "\n", - "参考答案:\n", - "\n", - "$$\n", - "\\left[\n", - "\\begin{matrix}\n", - "4 & 3 \\\\\n", - "2 & 6 \\\\\n", - "\\end{matrix}\n", - "\\right]\n", - "$$" + "$$\n" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "x = Variable(torch.FloatTensor([2, 3]), requires_grad=True)\n", - "k = Variable(torch.zeros(2))\n", + "x = torch.tensor([2, 3], dtype=torch.float, requires_grad=True)\n", + "k = torch.zeros(2)\n", "\n", "k[0] = x[0] ** 2 + 3 * x[1]\n", "k[1] = x[1] ** 2 + 2 * x[0]" @@ -475,31 +473,7 @@ }, { "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([2., 3., 4.], requires_grad=True)\n", - "tensor([2., 0., 0.])\n" - ] - } - ], - "source": [ - "# demo to show how to use `.backward`\n", - "x = torch.tensor([2,3,4], dtype=torch.float, requires_grad=True)\n", - "print(x)\n", - "y = x*2\n", - "\n", - "y.backward(torch.tensor([1, 0, 0], dtype=torch.float))\n", - "print(x.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -530,8 +504,10 @@ }, { "cell_type": "code", - "execution_count": 30, - "metadata": {}, + "execution_count": 24, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", @@ -545,11 +521,45 @@ "source": [ "print(j)" ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([2., 3., 4.], requires_grad=True)\n", + "tensor([2., 0., 0.])\n" + ] + } + ], + "source": [ + "# demo to show how to use `.backward`\n", + "x = torch.tensor([2,3,4], dtype=torch.float, requires_grad=True)\n", + "print(x)\n", + "y = x*2\n", + "\n", + "y.backward(torch.tensor([1, 0, 0], dtype=torch.float))\n", + "print(x.grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 参考资料\n", + "* [PyTorch 的 Autograd](https://zhuanlan.zhihu.com/p/69294347)\n", + "* [PyTorch学习笔记之自动求导(AutoGrad)](https://zhuanlan.zhihu.com/p/102942725)\n", + "* [Pytorch Autograd (自动求导机制)](https://www.cnblogs.com/wangqinze/p/13418291.html)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -563,7 +573,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.9.7" } }, "nbformat": 4, diff --git a/6_pytorch/3-linear-regression-gradient-descend.ipynb b/6_pytorch/3-linear-regression-gradient-descend.ipynb new file mode 100644 index 0000000..f2306c2 --- /dev/null +++ b/6_pytorch/3-linear-regression-gradient-descend.ipynb @@ -0,0 +1,909 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 线性模型的PyTorch实现\n", + "\n", + "本节简单回顾一下线性回归模型,并演示一下如何使用PyTorch来对线性回归模型进行建模和模型参数计算。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 一元线性回归\n", + "一元线性回归模型比较简单,假设有变量 $x_i$ 和目标 $y_i$,每个 i 对应于一个数据点,希望建立一个模型\n", + "\n", + "$$\n", + "\\hat{y}_i = w x_i + b\n", + "$$\n", + "\n", + "$\\hat{y}_i$ 是预测的结果,希望通过 $\\hat{y}_i$ 来拟合目标 $y_i$,通俗来讲就是找到这个函数拟合 $y_i$ 使得误差最小,即最小化\n", + "\n", + "$$\n", + "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "那么如何最小化这个误差呢?" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 梯度下降法\n", + "\n", + "在梯度下降法中,首先要明确梯度的概念,梯度在数学上就是导数,如果是一个多元函数,那么梯度就是偏导数。比如一个函数$f(x, y)$,那么 $f$ 的梯度就是 \n", + "\n", + "$$\n", + "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n", + "$$\n", + "\n", + "可以称为 grad f(x, y) 或者 $\\nabla f(x, y)$。具体某一点 $(x_0,\\ y_0)$ 的梯度就是 $\\nabla f(x_0,\\ y_0)$。\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "梯度有什么意义呢?从几何意义来讲,一个点的梯度值是这个函数变化最快的地方。具体来说,对于函数 f(x, y),在点 $(x_0, y_0)$ 处,沿着梯度 $\\nabla f(x_0,\\ y_0)$ 的方向,函数增加最快,也就是说沿着梯度的方向,能够更快地找到函数的极大值点,或者反过来沿着梯度的反方向,能够更快地找到函数的最小值点。\n", + "\n", + "针对一元线性回归问题,就是沿着梯度的反方向,不断改变 $w$ 和 $b$ 的值,最终找到一组最好的 $w$ 和 $b$ 使得误差最小。\n", + "\n", + "在更新的时候,需要决定每次更新的幅度就是每次往下走的那一步的长度,这个长度称为学习率,用 $\\eta$ 表示。不同的学习率都会导致不同的结果,学习率太小会导致下降非常缓慢;学习率太大又会导致跳动非常明显。\n", + "\n", + "最后我们的更新公式就是\n", + "\n", + "$$\n", + "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n", + "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n", + "$$\n", + "\n", + "通过不断地迭代更新,最终我们能够找到一组最优的 $w$ 和 $b$。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. PyTorch实现\n", + "\n", + "上面是原理部分,下面通过一个例子来进一步学习线性模型" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "torch.manual_seed(2021)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Matplotlib is building the font cache; this may take a moment.\n" + ] + }, + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOpklEQVR4nO3df4xlZ13H8fd3u25kEG3TnRItzA4YLJDGah1rIdKAVbGNkWCaWJ1AbIwTo1bwL4ibyB9mE0n8Q41RMqlojGNJ2LaKCVYajWCCrd7F/tiyoqXsDEvRTi2C6SSW7X7949y7O7u9M3Pu7D3nPPfe9yuZ3L3nnp35zrOzn/PMc5/nOZGZSJLKdaDrAiRJuzOoJalwBrUkFc6glqTCGdSSVLiDTXzSw4cP5+LiYhOfWpKm0okTJ57LzPlhrzUS1IuLi/R6vSY+tSRNpYhY3+k1hz4kqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtSTWtrsLgIBw5Uj2tr7XzdRqbnSdK0WVuDlRXY2qqer69XzwGWl5v92vaoJamGo0cvhPTA1lZ1vGkGtSTVsLEx2vFxMqglqYaFhdGOj5NBLUk1HDsGc3MXH5ubq443zaCWpBqWl2F1FY4cgYjqcXW1+TcSwVkfklTb8nI7wXwpe9SSVDiDWpIKZ1BLUuEMakkqnEEtSYUzqCWpcAa1JBXOoJakwhnUklQ4g1qSCmdQS1LhDGpJKpxBLUmFM6glqXC1gjoi3hcRJyPiyYh4f8M1SZK22TOoI+J64BeBm4AbgJ+MiDc0XZgkqVKnR/0m4OHM3MrMs8CngXc3W5YkaaBOUJ8EbomIqyNiDrgdeO2lJ0XESkT0IqK3ubk57jolaWbtGdSZeQr4MPAQ8CDwGHB2yHmrmbmUmUvz8/NjL1SSZlWtNxMz848z88bMvAV4HviPZsuSJA3UurltRFyTmc9GxALw08Bbmi1LkjRQ9y7k90XE1cA3gV/JzK81WJMkaZu6Qx9vy8w3Z+YNmfl3TRclzYq1NVhchAMHqse1ta4rUonq9qgljdnaGqyswNZW9Xx9vXoOsLzcXV0qj0vIpY4cPXohpAe2tqrj0nYGtdSRjY3Rjmt2GdRSRxYWRjuu2WVQSx05dgzm5i4+NjdXHZe2M6iljiwvw+oqHDkCEdXj6qpvJOrlnPUhdWh52WDW3uxRS3I+d+HsUUszzvnc5bNHLc0453OXz6CWZpzzuctnUEszzvnc5TOopRnnfO7yGdTSjHM+d/mc9SHJ+dyFs0ctSYUzqCWpcAa1JBXOoJakwhnUklQ4g1qSCmdQSx1y1zrV4TxqqSPuWqe67FFLHXHXOtVlUEsdcdc61WVQSx1x1zrVZVBLHXHXOtVlUEsdcdc61eWsD6lD7lqnOuxRS1LhDGpJM2USFxk59CFpZkzqIiN71JJmxqQuMjKoJe1pEocLhpnURUYGtaRdDYYL1tch88JwwSSG9aQuMjKoJe1qUocLhpnURUa1gjoifj0inoyIkxFxb0R8a9OFSRpNU8MTkzpcMMykLjLaM6gj4lrg14ClzLweuAK4s+nCJNXX5PDEpA4X7GR5GU6fhnPnqsfSQxrqD30cBF4REQeBOeCZ5kqSNKomhycmdbigTU2/2bpnUGfmV4DfATaArwJfz8xPXXpeRKxERC8iepubm+OtUtKumhyemNThgra08WZrZObuJ0RcBdwH/AzwP8DHgeOZ+ec7/Z2lpaXs9Xrjq1LSrhYXq4C41JEj1a/3as642j4iTmTm0rDX6gx9/CjwpczczMxvAvcDb63/5SU1zeGJ7rTxZmudoN4Abo6IuYgI4Fbg1PhKkHS5HJ7oThtvttYZo34EOA58Dnii/3dWx1eCpHGYxNkM06CN32ZqzfrIzA9l5hsz8/rMfE9m/t/4SpCkydXGbzPunidJl6npG0C4hFyShihpIyp71JJ0idL2rbZHrfNK6kFIXSptIyp71ALK60FIXSptIyp71ALK60FIXSptIyqDWkB5PQipS6Wt9DSoBZTXg5C6VNpKT4NaQHk9CKlrJa30NKgFlNeDkHSBsz50XtOrqyTtjz1qSSqcQS1JhTOoJalwBrUkFc6glqTCGdSSVDiDWpIKZ1BLUuEMakkqnEEtSYUzqCWpcAZ1gbwllqTtDOrCDG6Jtb4OmRduiWVYTw4vtBo3g7ow3hJrsnmhVRMM6sJ4S6zJ5oVWTTCoC+MtsSabF1o1waAujLfEmmxeaNUEg7ow3hJrsnmhVRO8FVeBvCXW5Br8ux09Wg13LCxUIe2/py6HQS2NmRdajZtDH5JUOINakgpnUEtS4QxqSSqcQS1JhdszqCPiuoh4dNvHNyLi/S3UJkmixvS8zPwC8H0AEXEF8BXggWbLkiQNjDr0cSvwxcxcb6IYSdLLjRrUdwL3DnshIlYiohcRvc3NzcuvTJIEjBDUEXEI+Cng48Nez8zVzFzKzKX5+flx1SdJM2+UHvVtwOcy87+aKkaS9HKjBPXPssOwhySpObWCOiLmgB8D7m+2HEnSpWrtnpeZW8DVDdciSRrClYmSVDiDWpIKZ1BLUuEMakkqnEEtSYUzqCWpcAa1JBXOoJakwhnUklQ4g1qSCmdQS1LhDGpJKtzEBPXaGiwuwoED1ePaWtcVSVI7JiKo19ZgZQXW1yGzelxZMay74kVTatdEBPXRo7C1dfGxra3quNrlRVNq30QE9cbGaMfbNGu9Sy+aUvsmIqgXFkY73pZZ7F2WfNGUptVEBPWxYzA3d/GxubnqeJdmsXdZ6kVTmmYTEdTLy7C6CkeOQET1uLpaHe/SLPYuS71oStNsIoIaqlA+fRrOnaseuw5pmM3eZakXTWmaTUxQl2hWe5clXjSlaWZQXwZ7l5LaYFBfJnuXKtWsTR2dZge7LkDS+A2mjg5mJQ2mjoKdiUlkj1qaQrM4dXSaGdTSFJrFqaPTzKCWptAsTh2dZga1NIVmderotDKopSnk1NHp4qwPaUotLxvM08IetSQVzqCWpMIZ1JJUOINakgpnUEtS4WoFdURcGRHHI+LfIuJURLyl6cIkSZW60/N+D3gwM++IiEPA3F5/QZI0HnsGdUR8O3AL8PMAmfki8GKzZUmSBuoMfbwe2AT+JCL+NSLuiYhXNlyXJKmvTlAfBG4E/igzvx94AfjgpSdFxEpE9CKit7m5OeYypea4wb5KVyeozwBnMvOR/vPjVMF9kcxczcylzFyan58fZ41SYwYb7K+vQ+aFDfYNa5Vkz6DOzP8EvhwR1/UP3Qp8vtGqpJa4wb4mQd1ZH3cDa/0ZH08DdzVXktQeN9jXJKgV1Jn5KLDUbClS+xYWquGOYcelUrgyUTPNDfY1CQxqzTQ32Nck8MYBmnlusK/S2aOWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrUkFa6YoHZPYEkaroiViYM9gQfbTQ72BAZXjElSET1q9wSWpJ0VEdTuCSxJOysiqHfa+9c9gSWpkKB2T2BJ2lkRQe2ewJK0syJmfYB7AkvSToroUUuSdmZQS1LhDGpJKpxBrX1z2b/UjmLeTNRkcdm/1B571NoXl/1L7TGotS8u+5faY1BrX1z2L7XHoNa+uOxfao9BrX1x2b/UHmd9aN9c9i+1wx61JBXOoJakwhnUklQ4g7pgLtGWBL6ZWCyXaEsasEddKJdoSxqo1aOOiNPA/wIvAWczc6nJouQSbUkXjDL08Y7MfK6xSnSRhYVquGPYcUmzxaGPQrlEW9JA3aBO4FMRcSIiVoadEBErEdGLiN7m5ub4KpxRLtGWNBCZufdJEd+Vmc9ExDXAQ8DdmfmZnc5fWlrKXq83xjIlabpFxImd3v+r1aPOzGf6j88CDwA3ja88SdJu9gzqiHhlRLxq8Gfgx4GTTRcmSarUmfXxauCBiBic/xeZ+WCjVUmSztszqDPzaeCGFmqRJA3h9DxJKlytWR8jf9KITWDIco2ZcBiY9YVBtoFtALYBjNYGRzJzftgLjQT1LIuI3qwvsbcNbAOwDWB8beDQhyQVzqCWpMIZ1OO32nUBBbANbAOwDWBMbeAYtSQVzh61JBXOoJakwhnU+xQRPxERX4iIpyLig0NeX46Ix/sfn42IqVvduVcbbDvvByPipYi4o836mlbn+4+It0fEoxHxZER8uu0am1bj/8F3RMRfR8Rj/Ta4q4s6mxQRH42IZyNi6B5IUfn9fhs9HhE3jvxFMtOPET+AK4AvAq8HDgGPAW++5Jy3Alf1/3wb8EjXdbfdBtvO+3vgk8AdXdfd8s/AlcDngYX+82u6rruDNvgN4MP9P88DzwOHuq59zO1wC3AjcHKH128H/gYI4Ob9ZIE96v25CXgqM5/OzBeBjwHv2n5CZn42M7/Wf/ow8JqWa2zanm3QdzdwH/Bsm8W1oM73/3PA/Zm5Aee3CZ4mddoggVdFtavbt1EF9dl2y2xWVnvzP7/LKe8C/iwrDwNXRsR3jvI1DOr9uRb48rbnZ/rHdvILVFfUabJnG0TEtcC7gY+0WFdb6vwMfA9wVUT8Q//uSO9trbp21GmDPwDeBDwDPAG8LzPPtVNeMUbNi5cZ5ea2uiCGHBs6zzEi3kEV1D/caEXtq9MGvwt8IDNf6m+TO03qfP8HgR8AbgVeAfxTRDycmf/edHEtqdMG7wQeBX4E+G7goYj4x8z8RsO1laR2XuzEoN6fM8Brtz1/DVWP4SIR8b3APcBtmfnfLdXWljptsAR8rB/Sh4HbI+JsZv5lKxU2q873fwZ4LjNfAF6IiM9QbRk8LUFdpw3uAn47q8HapyLiS8AbgX9up8Qi1MqL3Tj0sT//ArwhIl4XEYeAO4FPbD8hIhaA+4H3TFEPars92yAzX5eZi5m5CBwHfnlKQhpqfP/AXwFvi4iDETEH/BBwquU6m1SnDTaofqMgIl4NXAc83WqV3fsE8N7+7I+bga9n5ldH+QT2qPchM89GxK8Cf0v1zvdHM/PJiPil/usfAX4TuBr4w36P8mxO0U5iNdtgatX5/jPzVEQ8CDwOnAPuycypuY1dzZ+B3wL+NCKeoBoC+EBmTtXWpxFxL/B24HBEnAE+BHwLnG+DT1LN/HgK2KL6LWO0r9GfPiJJKpRDH5JUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrUkFe7/AeTSyedpFuSCAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# 生层测试数据\n", + "x_train = np.random.rand(20, 1)\n", + "y_train = x_train * 3 + 4 + 3*np.random.rand(20,1)\n", + "\n", + "# 画出图像\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "plt.plot(x_train, y_train, 'bo')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# 转换成 Tensor\n", + "x_train = torch.from_numpy(x_train)\n", + "y_train = torch.from_numpy(y_train)\n", + "\n", + "# 定义参数 w 和 b\n", + "w = torch.randn(1, requires_grad=True) # 随机初始化\n", + "b = torch.zeros(1, requires_grad=True) # 使用 0 进行初始化" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# 构建线性回归模型\n", + "def linear_model(x):\n", + " return x * w + b\n", + "\n", + "def logistc_regression(x):\n", + " return torch.sigmoid(x*w+b) " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "y_ = linear_model(x_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "经过上面的步骤我们就定义好了模型,在进行参数更新之前,我们可以先看看模型的输出结果长什么样" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAW1UlEQVR4nO3df2xd5X3H8c/XsUMwoRQlHgJS2zANmpT8IDEsbUcIkIa0QSuI/lHmloYOBchAbFo7YJFGpRR1SBMp6fhlZSlacYdGoIxNWcdaYKEiFGzq0DYpCQtOcGBKYlAK+aEk9nd/HNtxnHt9z7XPOfe5975fkuXccy/nPPfYfO7j53zP85i7CwAQrppSNwAAMDqCGgACR1ADQOAIagAIHEENAIGrTWOnU6dO9ebm5jR2DQAVqbOzc5+7N+R6LpWgbm5uVkdHRxq7BoCKZGY78z3H0AcABI6gBoDAEdQAELhUxqhzOXr0qHp6enT48OGsDlnxJk2apGnTpqmurq7UTQGQosyCuqenR6effrqam5tlZlkdtmK5u3p7e9XT06Pzzjuv1M0BkKLMhj4OHz6sKVOmENIJMTNNmTKFv1CADLW3S83NUk1N9L29PZvjZtajlkRIJ4zzCWSnvV1avlw6eDB6vHNn9FiSWlvTPTYXEwEghpUrj4f0oIMHo+1pI6hjam5u1r59+0rdDAAlsmtXcduTFGxQpzkW5O7q7+9PbocAKl5jY3HbkxRkUA+OBe3cKbkfHwsaT1h3d3dr+vTpWrFihebOnatVq1bpkksu0axZs3TvvfcOve7aa6/VvHnz9JnPfEZtbW0JvBsAleC++6T6+hO31ddH29MWZFCnNRb01ltv6cYbb9T999+v3bt367XXXlNXV5c6Ozu1ceNGSdK6devU2dmpjo4OrVmzRr29veM7KICK0NoqtbVJTU2SWfS9rS39C4lSxlUfcaU1FtTU1KT58+frW9/6lp5//nldfPHFkqSPP/5Y27dv14IFC7RmzRr95Cc/kSS9++672r59u6ZMmTK+AwOoCK2t2QTzSEEGdWNjNNyRa/t4nHbaaZKiMep77rlHt9xyywnPv/TSS/rZz36mTZs2qb6+XgsXLqROGUDJBTn0kfZY0NVXX61169bp448/liTt3r1be/bs0f79+3XmmWeqvr5ev/vd7/Tqq68mc0AAGIcge9SDf1qsXBkNdzQ2RiGd1J8cixcv1tatW/XZz35WkjR58mQ98cQTWrJkiR599FHNmjVLF154oebPn5/MAQFgHMzdE99pS0uLj1w4YOvWrZo+fXrix6p2nFegMphZp7u35HouyKEPAMBxBDUABI6gBkqoVLOxobwEeTERqAalnI0N5YUeNVAipZyNDeWFoAZKpJSzsaG8ENQ5PP7443rvvfeGHt98883asmXLuPfb3d2tH//4x0X/d8uWLdP69evHfXyEpZSzsY3EWHnYwg3qEv7mjAzqtWvXasaMGePe71iDGpWplLOxDZfGbJVIVphBndJvzhNPPKFLL71Uc+bM0S233KK+vj4tW7ZMF110kWbOnKnVq1dr/fr16ujoUGtrq+bMmaNDhw5p4cKFGryBZ/Lkybrrrrs0b948LVq0SK+99poWLlyo888/X88995ykKJAvu+wyzZ07V3PnztUrr7wiSbr77rv18ssva86cOVq9erX6+vr07W9/e2i61ccee0xSNBfJ7bffrhkzZmjp0qXas2fPuN43wlTK2diGY6y8DLh74l/z5s3zkbZs2XLStryamtyjiD7xq6kp/j5yHP+aa67xI0eOuLv7bbfd5t/5znd80aJFQ6/58MMP3d398ssv99dff31o+/DHknzDhg3u7n7ttdf6F77wBT9y5Ih3dXX57Nmz3d39wIEDfujQIXd337Ztmw+ejxdffNGXLl06tN/HHnvMV61a5e7uhw8f9nnz5vmOHTv86aef9kWLFvmxY8d89+7dfsYZZ/hTTz2V930B42GW+383s1K3rLpI6vA8mRpmeV4KV1l+/vOfq7OzU5dccokk6dChQ1qyZIl27NihO+64Q0uXLtXixYsL7mfixIlasmSJJGnmzJk65ZRTVFdXp5kzZ6q7u1uSdPToUd1+++3q6urShAkTtG3btpz7ev755/Xmm28OjT/v379f27dv18aNG3XDDTdowoQJOuecc3TllVeO+X0DhaQ1WyWSE+bQRwpXWdxd3/jGN9TV1aWuri699dZbevDBB7V582YtXLhQDz30kG6++eaC+6mrqxta/bumpkannHLK0L+PHTsmSVq9erXOOussbd68WR0dHTpy5EjeNv3gBz8YatM777wz9GHBCuPISihj5cgvzKBO4Tfnqquu0vr164fGez/44APt3LlT/f39uv7667Vq1Sq98cYbkqTTTz9dH3300ZiPtX//fp199tmqqanRj370I/X19eXc79VXX61HHnlER48elSRt27ZNBw4c0IIFC/Tkk0+qr69P77//vl588cUxtwVhC6HaIpSxcuQX5tBHCvOczpgxQ9/97ne1ePFi9ff3q66uTg888ICuu+66oYVuv/e970mKyuFuvfVWnXrqqdq0aVPRx1qxYoWuv/56PfXUU7riiiuGFiyYNWuWamtrNXv2bC1btkx33nmnuru7NXfuXLm7Ghoa9Oyzz+q6667TCy+8oJkzZ+qCCy7Q5ZdfPub3jXCFdGdiqVYuQTxMc1rmOK/lq7k599hwU5M0cLkDVYRpToEAcWci4iKogRIJ6c5EhC3ToE5jmKWacT7LG9UWpRHCBdxiZRbUkyZNUm9vL+GSEHdXb2+vJk2aVOqmYIzKqdqiHMMtl3K9XT7WxUQz+ytJN0tySb+WdJO7H873+lwXE48ePaqenh4dPpz3P0ORJk2apGnTpqmurq7UTUEFG1mdIkU9/1A/VEYT8gXc0S4mFgxqMztX0i8kzXD3Q2b2r5I2uPvj+f6bXEENIF3t7YlWtA4JOdyKVVMT9aRHMpMGqnRLJomqj1pJp5pZraR6Se8VeD2ADKX5J30lVaekdQE37aGhgkHt7rsl/YOkXZLel7Tf3Z9PthkAxiPNGfAqqToljQu4WYx7FwxqMztT0pclnSfpHEmnmdnXcrxuuZl1mFnH3r17k2shgILS7PVWUnVKGhdws5gmNs7QxyJJ77j7Xnc/KukZSZ8b+SJ3b3P3FndvaWhoSK6FAApKs9dbTtUpcbS2RmPr/f3R93zvI+5wRhZDQ3GCepek+WZWb9GUbldJ2ppcExCKSinBqkZp93rjhlulKGY4I4uhoThj1L+UtF7SG4pK82oktSXXBISgXOtLEam0Xm+pFTOckcXQUGaTMiFslVSCBYxXsWV8SZRGjlaeF+Y0p8hcJZVgAeNV7Ko3aU8Ty6RMkFRZJVjAeIVW6UJQQ1J4v5hAKYU25s/QBySlsqgOUNZCWvWGoMaQkH4xARzH0AcABI6gDhA3ngAYjqAODDeelD8+aJE0gjowWUzwgvTwQYs0ENSB4caT8sYHLdJAUAeGG0/KGx+0SANBHRhuPClvfNAiDQR1YEK7IwrF4YMWaeCGlwBx40n54g5PpIGgBhLGBy2SxtAHAASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHBlE9TM8RsOfhZAtsoiqEOe47faQivknwVQqczdE99pS0uLd3R0JLa/5uYoEEZqapK6uxM7TNEGQ2v4/MP19ZU9iVKoPwug3JlZp7u35HyuHIK6pibqvY1kJvX3J3aYolVjaIX6swDK3WhBXRZDH6HO8VuNk8SH+rPAyaptWK6SlUVQhzrHbzWGVqg/C5yIawmVpSyCOtTJ9KsxtEL9WeBErN1YWcpijDpk7e1MEo/wcC2h/Iw2Rs3CAePEJPEIUWNj7gvdlTwsV8nKYugDQHGqcViukhHUQAXiWkJliTX0YWaflLRW0kWSXNI33X1Tiu0CME4My1WOuD3qByX91N0/LWm2pK3pNQnIFvXGCF3BHrWZfULSAknLJMndj0g6km6zgGyMnAZgsN5YojeKcMTpUZ8vaa+kH5rZr8xsrZmdNvJFZrbczDrMrGPv3r2JNxRIA/XGKAdxgrpW0lxJj7j7xZIOSLp75Ivcvc3dW9y9paGhIeFmAumoxmkAUH7iBHWPpB53/+XA4/WKghsoe9U4DQDKT8Ggdvf/k/SumV04sOkqSVtSbRWQEeqNUQ7i3pl4h6R2M5soaYekm9JrEpCdwQuGTAOAkDHXBwAEoCzmo6aWFQByC2JSJmpZASC/IHrU1LICQH5BBDW1rACQXxBBTS0rAOQXRFBTy1qeuAAMZCOIoGbu3PLD4qlAdqijxpg0N+de6qmpSeruzro1QPkrizpqlBcuAAPZIagxJlwABrJDUAcs5It1XAAGskNQByr0i3VcAAayw8XEQHGxDqguXEwsQ1ysAzCIoA4UF+sADCKoA8XFOgCDCOpAcbEOwKAg5qNGbq2tBDMAetQAEDyCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAELnZQm9kEM/uVmf1Hmg0CAJyomB71nZK2ptUQAEBusYLazKZJWippbbrNAQCMFLdH/X1JfyOpP98LzGy5mXWYWcfevXuTaBsAQDGC2syukbTH3TtHe527t7l7i7u3NDQ0JNZAAKh2cXrUn5f0p2bWLelJSVea2ROptgoAMKRgULv7Pe4+zd2bJX1V0gvu/rXUWwYAkEQdNQAEr7aYF7v7S5JeSqUlAICc6FEDQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AIxXe7vU3CzV1ETf29sT3T1BDQC55ArffNuWL5d27pTco+/Llyca1ubuie1sUEtLi3d0dCS+XwDIxGD4Hjx4fNvEiVEQHz16fFt9vXTqqVJv78n7aGqSurtjH9LMOt29Jddz9KgBVJc4wxQrV54Y0pJ05MiJIS1Fr8kV0pK0a1cSrZUk1Sa2JwAI3cie8uAwhSS1th5/XRIh29g4/n0MKNijNrNPmdmLZrbVzH5rZncmdnQAyFKunvLBg9H24YoJ2SlToiGQ4errpfvuG1sbc4gz9HFM0l+7+3RJ8yX9hZnNSKwFAJCVfD3lkdvvu+/k8J04UaqrO3Fbfb304INSW1s0Jm0WfW9rO7GHPk4Fhz7c/X1J7w/8+yMz2yrpXElbEmsFAGShsTEa7si1fbjBkF25MgrxxsbjPeSR2wZfm2Awj1RU1YeZNUvaKOkid//9iOeWS1ouSY2NjfN25joZAFBKuao56usT7wGPRSJVH2Y2WdLTkv5yZEhLkru3uXuLu7c0NDSMvbUAkJbW1tSHKdIQK6jNrE5RSLe7+zPpNglAcFK+8y5Tra1RfXN/f/Q98JCWYoxRm5lJ+idJW939gfSbBCAocUvakJo4PerPS/q6pCvNrGvg60sptwtAlkbrMcctaUNqCga1u//C3c3dZ7n7nIGvDVk0DsA4xB2uKDRXRdySNqSGW8iBSjE8mKdOlb75zXgTBRXqMee7+SPBO+8wOoIaqAQje8W9vdHcFMPlG64o1GPOdfNHwnfeYXQENVAJcvWKc8kVyoV6zGVa0lZJCGqglFaskGprowCsrY0ej0Xc8eJcoRynx1yGJW2VhKAGkhb3It6KFdIjj0h9fdHjvr7o8VjCOs54cb7hCnrMwWPhACBJxdyiXFt7PKSHmzBBOnZs/Metq5M+8Qnpgw9OnpcCwWHhACArxdQc5wrp0baPJlev+Ic/lPbtY7iiAhDUQJK3RxdTczxhQu7X5tteCOPIFYugRnVLemHSYmqOB2/DjrsdVYugRnVL+vboYmqOH35Yuu224z3oCROixw8/PLZjo2JxMRHVraYm6kmPZBYNIYxFe3v+yeWBPEa7mMjitqhucVf8KEZrK8GMRDH0gerG7dEoAwQ1qhs3e6AMMPQBMFSBwNGjRvEG644H56cwK//lmYCAEdQ4Ls6NH8PrjqXjd9GNt/4YQF4ENSJxb/wYbTpNlmcCUkFQIxL3xo9C02myPBOQOIIakbhzVBSqL2Z5JiBxBDUiceeoyFV3PIj6YyAVBHWISlFVEffGj+F1x9LxeSqoPwZSQx11aEZOAD+yqkJKJwwH9xlnjgrqjoFMMSlTaJqbc889MaipKZprGEBFYYWXNCU56bxEVQWAkxDU45H0pPMSVRUATkJQj0fSk85LVFUAOAlBPR7FrI8XF1UVAEag6mM80ph0XqKqAsAJ6FGPB5POA8hA+QR10tUVSWDSeQAZCCeoRwviNKorktLaGtU19/dH3wlpAAkLI6gLBXEa1RUAUCZiBbWZLTGzt8zsbTO7O/FWFAriNKorAKBMFAxqM5sg6SFJX5Q0Q9INZjYj0VYUCuK4M7sBQAWK06O+VNLb7r7D3Y9IelLSlxNtRaEgproCQBWLE9TnSnp32OOegW0nMLPlZtZhZh179+4trhWFgpjqCgBVLE5QW45tJ0255+5t7t7i7i0NDQ3FtSJOEFNdAaBKxbkzsUfSp4Y9nibpvcRbwt14AJBTnB7165L+yMzOM7OJkr4q6bl0mwUAGFSwR+3ux8zsdkn/JWmCpHXu/tvUWwYAkBRzUiZ33yBpQ8ptAQDkEMadiQCAvAhqAAhcKovbmtleSaOs0FrRpkraV+pGlBjngHMgcQ6k4s5Bk7vnrG1OJairmZl15FtJuFpwDjgHEudASu4cMPQBAIEjqAEgcAR18tpK3YAAcA44BxLnQEroHDBGDQCBo0cNAIEjqAEgcAT1GBVanszMWs3szYGvV8xsdinamaa4S7SZ2SVm1mdmX8myfWmL8/7NbKGZdZnZb83sf7JuY9pi/H9whpn9u5ltHjgHN5WinWkys3VmtsfMfpPneTOzNQPn6E0zm1v0QdydryK/FE1O9b+Szpc0UdJmSTNGvOZzks4c+PcXJf2y1O3O+hwMe90LiuaK+Uqp253x78AnJW2R1Djw+A9K3e4SnIO/lXT/wL8bJH0gaWKp257weVggaa6k3+R5/kuS/lPR3P7zx5IF9KjHpuDyZO7+irt/OPDwVUXzeFeSuEu03SHpaUl7smxcBuK8/z+T9Iy775Ikd6/Gc+CSTjczkzRZUVAfy7aZ6XL3jYreVz5flvTPHnlV0ifN7OxijkFQj02s5cmG+XNFn6iVpOA5MLNzJV0n6dEM25WVOL8DF0g608xeMrNOM7sxs9ZlI845+EdJ0xUtNvJrSXe6e382zQtGsXlxkljTnOIksZYnkyQzu0JRUP9Jqi3KXpxz8H1Jd7l7X9Shqihx3n+tpHmSrpJ0qqRNZvaqu29Lu3EZiXMOrpbUJelKSX8o6b/N7GV3/33KbQtJ7LzIh6Aem1jLk5nZLElrJX3R3XszaltW4pyDFklPDoT0VElfMrNj7v5sJi1MV5z33yNpn7sfkHTAzDZKmi2pUoI6zjm4SdLfezRY+7aZvSPp05Jey6aJQRj3coYMfYxNweXJzKxR0jOSvl5BPajhCp4Ddz/P3ZvdvVnSekkrKiSkpXhL1P2bpMvMrNbM6iX9saStGbczTXHOwS5Ff1HIzM6SdKGkHZm2svSek3TjQPXHfEn73f39YnZAj3oMPM/yZGZ268Dzj0r6O0lTJD080KM85hU0k1jMc1Cx4rx/d99qZj+V9Kakfklr3T1nCVc5ivk7sErS42b2a0VDAHe5e0VNfWpm/yJpoaSpZtYj6V5JddLQOdigqPLjbUkHFf2VUdwxBspHAACBYugDAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDA/T9GRnWgZHl9GwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", + "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**思考:红色的点表示预测值,似乎排列成一条直线,请思考一下这些点是否在一条直线上?**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "这个时候需要计算我们的误差函数,也就是\n", + "\n", + "$$\n", + "E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# 计算误差\n", + "def get_loss(y_, y):\n", + " return torch.sum((y_ - y) ** 2)\n", + "\n", + "loss = get_loss(y_, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(733.2964, dtype=torch.float64, grad_fn=)\n" + ] + } + ], + "source": [ + "# 打印一下看看 loss 的大小\n", + "print(loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "定义好了误差函数,接下来我们需要计算 $w$ 和 $b$ 的梯度了,这时得益于 PyTorch 的自动求导,不需要手动去算梯度就可以得到计算好的梯度值。手动计算的$w$ 和 $b$ 的梯度分别是\n", + "\n", + "$$\n", + "\\frac{\\partial}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n", + "\\frac{\\partial}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# 自动求导\n", + "loss.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-135.3880])\n", + "tensor([-239.5816])\n" + ] + } + ], + "source": [ + "# 查看 w 和 b 的梯度\n", + "print(w.grad)\n", + "print(b.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# 更新一次参数\n", + "w.data = w.data - 1e-2 * w.grad.data\n", + "b.data = b.data - 1e-2 * b.grad.data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "更新完成参数之后,我们再一次看看模型输出的结果" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAX10lEQVR4nO3df5DU9X3H8eebAySnVJ3z4mjI3WnHWpAf53FaTCsSQSDBSbTJH9pLlUwtKKMlnUmq1pmaDrGpf5Fq448bw2gD0VHU1qZUMYkGO/7AOwuooGDInR7acpyGRITC3b37x/f2OM5d9ru3+9397O7rMXOzt7tfvvvZ73157Wc/388Pc3dERCRc40pdABEROT4FtYhI4BTUIiKBU1CLiAROQS0iErjxSez0tNNO86ampiR2LSJSkTo7O/e5e3265xIJ6qamJjo6OpLYtYhIRTKz7kzPqelDRCRwCmoRkcApqEVEApdIG3U6R44coaenh0OHDhXrJSvepEmTmDJlChMmTCh1UUQkQUUL6p6eHiZPnkxTUxNmVqyXrVjuTl9fHz09PZx11lmlLo6IJKhoTR+HDh2irq5OIV0gZkZdXZ2+oYgU0bp10NQE48ZFt+vWFed1i1ajBhTSBabjKVI869bBsmXwySfR/e7u6D5AW1uyr62LiSIiMdx229GQTvnkk+jxpCmoY2pqamLfvn2lLoaIlMi77+b2eCEFG9RJtgW5O4ODg4XboYhUvIaG3B4vpCCDOtUW1N0N7kfbgvIJ666uLqZOncqKFStoaWlh1apVXHDBBcycOZPbb799eLsrrriC2bNnc95559He3l6AdyMileCOO6C29tjHamujx5MWZFAn1Rb09ttvc80113DnnXeyZ88eNm/ezJYtW+js7GTTpk0ArFmzhs7OTjo6Orjrrrvo6+vL70VFpCK0tUF7OzQ2gll0296e/IVEKHKvj7iSagtqbGxkzpw5fPvb32bjxo2cf/75AHz88cfs2rWLuXPnctddd/Hkk08C8N5777Fr1y7q6urye2ERqQhtbcUJ5tGCDOqGhqi5I93j+TjxxBOBqI361ltvZfny5cc8//zzz/Ozn/2Ml156idraWubNm6d+yiJSckE2fSTdFrRo0SLWrFnDxx9/DMCePXvYu3cv+/fv59RTT6W2tpa33nqLl19+uTAvKCKShyBr1KmvFrfdFjV3NDREIV2orxwLFy5kx44dXHTRRQCcdNJJrF27lsWLF3Pfffcxc+ZMzj33XObMmVOYFxQRyYO5e8F32tra6qMXDtixYwdTp04t+GtVOx1XkcpgZp3u3pruuSCbPkRE5CgFtYhI4GIFtZmtNLM3zOxNM/tWwmUSEZERsga1mU0H/hK4EJgFXG5m5yRdMBERicSpUU8FXnb3T9y9H/glcGWyxRIRkZQ4Qf0GMNfM6sysFvgy8PnRG5nZMjPrMLOO3t7eQpdTRKRqZQ1qd98B3Ak8CzwNbAX602zX7u6t7t5aX19f8IIW04MPPsj7778/fP+6665j+/btee+3q6uLn/zkJzn/u6VLl7J+/fq8X19EylOsi4nu/iN3b3H3ucCHwK5ki0Xp1rzh00H9wAMPMG3atLz3O9aglspVwtNcykjcXh+fHbptAP4UeDjJQiUyzymwdu1aLrzwQpqbm1m+fDkDAwMsXbqU6dOnM2PGDFavXs369evp6Oigra2N5uZmDh48yLx580gN4DnppJO4+eabmT17NgsWLGDz5s3MmzePs88+m6eeegqIAvniiy+mpaWFlpYWXnzxRQBuueUWXnjhBZqbm1m9ejUDAwN85zvfGZ5u9f777weiuUhuvPFGpk2bxpIlS9i7d29e71vClNBpLpXI3bP+AC8A24maPeZn23727Nk+2vbt2z/1WEaNje7RuXvsT2Nj/H2kef3LL7/cDx8+7O7uN9xwg3/3u9/1BQsWDG/z0Ucfubv7JZdc4q+++urw4yPvA75hwwZ3d7/iiiv8sssu88OHD/uWLVt81qxZ7u5+4MABP3jwoLu779y501PH47nnnvMlS5YM7/f+++/3VatWubv7oUOHfPbs2b57925//PHHfcGCBd7f3+979uzxk08+2R977LGM70vKUwKnuZQxoMMzZGqsuT7c/eLEPinSSWCe05///Od0dnZywQUXAHDw4EEWL17M7t27uemmm1iyZAkLFy7Mup+JEyeyePFiAGbMmMEJJ5zAhAkTmDFjBl1dXQAcOXKEG2+8kS1btlBTU8POnTvT7mvjxo1s27ZtuP15//797Nq1i02bNnH11VdTU1PDmWeeyaWXXjrm9y3hKuXSTlJegpyUKYl5Tt2da6+9lu9///vHPH7HHXfwzDPP8MMf/pBHH32UNWvWHHc/EyZMGF79e9y4cZxwwgnDv/f3R9dYV69ezemnn87WrVsZHBxk0qRJGct09913s2jRomMe37Bhg1YYrwJJTecrlSfMIeQJzHM6f/581q9fP9ze++GHH9Ld3c3g4CBf+9rXWLVqFa+99hoAkydP5ne/+92YX2v//v2cccYZjBs3jh//+McMDAyk3e+iRYu49957OXLkCAA7d+7kwIEDzJ07l0ceeYSBgQE++OADnnvuuTGXRcJVyqWdpLyEGdQJrHkzbdo0vve977Fw4UJmzpzJZZddRldXF/PmzaO5uZmlS5cO17aXLl3K9ddfP3wxMVcrVqzgoYceYs6cOezcuXN4wYKZM2cyfvx4Zs2axerVq7nuuuuYNm0aLS0tTJ8+neXLl9Pf38+VV17JOeecw4wZM7jhhhu45JJLxvy+JVylXNppNPU+CZumOS1zOq6Sr1Tvk5HrlNbWlu5Do1ppmlMRySipxaSlcBTUIlVOvU/CV9SgTqKZpZrpeEohZOplot4n4ShaUE+aNIm+vj6FS4G4O319fRm7/onEpd4n4StaP+opU6bQ09ODZtYrnEmTJjFlypRSF0PKXNKLSUv+itbrQ0REMlOvD5FAqf+yxBHmEHKRKjC6/3Jq9jxQs4McSzVqkRJR/2WJS0EtUiLqvyxxKahFSkT9lyUuBbVIiaj/ssSloBYpkZBmz6sm5djTRr0+REqorU3BXEzl2tNGNWoRyaoca6HplGtPG9WoReS4yrUWmk659rSJVaM2s782szfN7A0ze9jMNBOQSGCSqvWWay00nXLtaZM1qM3sc8BfAa3uPh2oAa5KumAiEl+q1tvdDe5Ha72FCOtyrYWmk1RPm6SbhuK2UY8HPmNm44Fa4P3CFkNE8pFkrbdca6HpJNHTJskPyZRYs+eZ2UrgDuAgsNHdP/W2zGwZsAygoaFhdnd3d+FKKSLHNW5cFBKjmcHgYH771pqKx9fUFIXzaI2N0NUVfz95zZ5nZqcCXwXOAs4ETjSzb4zezt3b3b3V3Vvr6+vjl05E8pZkrVf9vY+vGE1DcZo+FgC/dvdedz8CPAF8oXBFEJF8JT3Ksa0tqh0ODka3CumjitE0FCeo3wXmmFmtmRkwH9hRuCKISL5U6y28uBcIizEVQNagdvdXgPXAa8DrQ/+mvXBFkFBUyqCGaqVab+HkcoGwGB+SWopLAF0wEhmpUBcIc6GluCSrShrUIJKv0PqOK6gFCO/EFCml0PqOK6gFCO/EFCml0OYKV1ALEN6JKVJKofWi0ex5Ahw9AW+7LWruaGiIQloXEqVahTRXuIJahoV0YorIUWr6EBEJnIJaRCRwCmoRkcApqEVEAqegDpDm3BCRkRTUgSnGahGSLH3QSqEpqAOjOTfKmz5oJQkK6sBozo3ypg9aSYKCOjCac6O86YNWkqCgDozm3Chv+qCVJCioAxPaZDCSG33QShI010eANOdG+dLkVpIEBbVIgemDVgota9OHmZ1rZltG/PzWzL5VhLKJiAgxatTu/jbQDGBmNcAe4MlkiyUiIim5XkycD/zK3dOszysiIknINaivAh5O94SZLTOzDjPr6O3tzb9kIiIC5BDUZjYR+ArwWLrn3b3d3VvdvbW+vr5Q5RMRqXq51Ki/BLzm7v+bVGFEROTTcgnqq8nQ7CEiIsmJFdRmVgtcBjyRbHFERGS0WANe3P0ToC7hsoiISBqa60NEJHAKahGRwCmoRUQCp6AWEQmcglpEJHBlE9Ra2Tkc+luIFFdZzEedWtk5tWhoamVn0Ly/xaa/hUjxlUWNOuSVnautdhny30KkUpVFjTrUlZ2rsXYZ6t9CpJKVRY061JWdq7F2GerfQqSSlUVQh7qyczXWLkP9W4hUsrII6rY2aG+HxkYwi27b20vfvFCNtctQ/xbyadV2/aSSmbsXfKetra3e0dFR8P2GZnQbNUS1SwWXlJrOzfJjZp3u3pruubKoUYdKtUsJVTVeP6lkqlGLVKBx4yDdf20zGBwsfnkkO9WoRapMNV4/qWQKapEKpN45lUVBLVKBdP2kspTFyEQRyV1bm4K5UqhGLSISuLirkJ9iZuvN7C0z22FmFyVdMBGRspHw6KK4TR//BDzt7l83s4lAbbZ/ICJSFYowO1vWGrWZ/R4wF/gRgLsfdvffFOTVRQKgodYSW7qTpQiji7IOeDGzZqAd2A7MAjqBle5+YNR2y4BlAA0NDbO7u7sLVkiRpGiotcSW6WQZHdIpOY4uynfAy3igBbjX3c8HDgC3jN7I3dvdvdXdW+vr62MXTqSUNNRaMhpde165Mv3JUlOT/t8XcHRRnKDuAXrc/ZWh++uJgluk7FXjVLUSQ6r23N0djcXv7oa+vvTbDgwkProoa1C7+/8A75nZuUMPzSdqBhEpexpqLWml+6qVSWo0UYKji+L2o74JWGdm24Bm4B8KVgKREtJQ6yoU5+px3K9UqZOlrQ26uqI26a6ugl/giBXU7r5lqP15prtf4e4fFbQUIiWiodZVJl2TxrJlnw7rTF+p6upKcrJomlMRqR5NTVE4j9bYGNWEU0rQHUjTnIqIQPyrx4F91dKkTCJSPRoa0teo0zV1BDSrVTA1ao0OE5HElenV4yCCOm77vohIXgJr0ogriIuJcdv3RUQqVfAXEzU6TEQksyCCWqPDREQyCyKoy7R9X0SkKIII6jJt3xcRKYogghoSHyovCVCXSpHi0IAXGZMirD4kIkOCqVFLedGE+yLFo6CWMVGXSpHiUVDLmKhLZZXRBYmSUlDLmKhLZRXRHA8lp6CWMVGXyiqiCxIlp6AOWOjfNtWlsoIc72TTBYmSU/e8QKn7mxRNtpMtlzmcJRGqUQdK3zalaLKdbLogUXKxgtrMuszsdTPbYmZaDLEI9G1T8ha37SzbyaYLEiWXS9PHF919X2IlkWPo26bkJZe2szgnW0DLUlUjNX0ESt82JWcja9DXXhu/7UwnW/DiBrUDG82s08yWpdvAzJaZWYeZdfT29hauhFVK3zYlJ6P7Og8MpN8uXTOHTrbgxVqKy8zOdPf3zeyzwLPATe6+KdP2uS7FJSJ5yrSe3Wha3y5YeS/F5e7vD93uBZ4ELixc8UQkb3GuMqs5o2xlDWozO9HMJqd+BxYCbyRdMBHJQaarzDU1as6oAHFq1KcD/2VmW4HNwH+4+9PJFktEcpLpguBDD2noaAXIGtTuvtvdZw39nOfu+u4kUigrVsD48VGtd/z46P5Y6IJgRdMQcpFSWbEC7r336P2BgaP377kn9/2pr3PFUj9qkVJpb8/tcalaCmqRQos7dDtTX+dMj0vVUtOHSCHlMnS7piZ9KNfUJFtGKTuqUYsUUi7THi5LO8g38+NStVSjFimkXKY9TF0wbG+PatY1NVFIj+VColQ01ahFCinXVX/vuQf6+6P5Ofr7FdKSloJapJA0E50kQEEtUkgaeCIJUFCLFHoVYa36KwWmi4lS3bSKsJQB1ailumkVYSkDCmqpblpFWMqAglqqW67d6URKQEEt1U3d6aQMKKiluqk7nZQB9foQ0TzOEjjVqEVEAqegFhEJXOygNrMaM/tvM/tpkgUSEZFj5VKjXgnsSKogIiKSXqygNrMpwBLggWSLIyIio8WtUf8A+BtgMNMGZrbMzDrMrKO3t7cQZRMREWIEtZldDux1987jbefu7e7e6u6t9fX1BSugiEi1i1Oj/mPgK2bWBTwCXGpmaxMtlYiIDMsa1O5+q7tPcfcm4CrgF+7+jcRLJuEr9DzOIpKWRiZK7tatg5Uroa/v6GOax1kkMTkNeHH359398qQKI2UgNdH+yJBO0TzOIonQyETJTbqJ9kfSPM4iBaeglqPitDlnC2LN4yxScApqiaSaNLq7wf1om/PosD5eEGseZ5FEKKglEnftwHQT7QPU1WkeZ5GEKKglEnftwHQT7a9dC/v2KaRFEqLueRJpaIiaO9I9Ppom2hcpKtWoJaK1A0WCpaCWiNYOFAmWmj7kKDVpiARJNWoRkcApqEOUGnhiBuPHR7ea9EikaqnpIzSpgSepPs0DA9GtJj0SqVqqUYfmeHNpaNIjkaqkoA5Ntrk0NOmRSNVRUIcm26RGmvRIpOooqEOTaS4N0AAUkSqloA7NyIEnADU10a0GoIhULQV1vpJYN7CtDbq6oulG+/uj264uhbRIlVL3vHyM7kqnLnQikgDVqPMRdw5nEZE8ZA1qM5tkZpvNbKuZvWlmf1+MgpWFuHM4i4jkIU6N+v+AS919FtAMLDazOYmWqlxk6iqnLnQiUkBZg9ojHw/dnTD044mWqlxoDmcRKYJYbdRmVmNmW4C9wLPu/kqabZaZWYeZdfT29ha4mIHSHM4iUgTmHr9ybGanAE8CN7n7G5m2a21t9Y6OjvxLJyJSJcys091b0z2XU68Pd/8N8DywOP9iiYhIHHF6fdQP1aQxs88AC4C3Ei6XiIgMiTPg5QzgITOrIQr2R939p8kWS0REUrIGtbtvA84vQllERCSN8hmZmMScGiIiZaA85vrQnBoiUsXKo0atOTVEpIqFE9THa9rQnBoiUsXCCOpU00Z3dzT3cqppIxXWmlNDRKpYGEGdrWlDc2qISBULI6izNW1oTg0RqWJh9PpoaIiaO9I9ntLWpmAWkaoURo1aTRsiIhmFEdRq2hARySiMpg9Q04aISAZh1KhFRCQjBbWISOAU1CIigVNQi4gETkEtIhK4nBa3jb1Ts14gzQiWqnAasK/UhSgxHQMdA9AxgNyOQaO716d7IpGgrmZm1pFpJeFqoWOgYwA6BlC4Y6CmDxGRwCmoRUQCp6AuvPZSFyAAOgY6BqBjAAU6BmqjFhEJnGrUIiKBU1CLiAROQT1GZrbYzN42s3fM7JY0z7eZ2bahnxfNbFYpypmkbMdgxHYXmNmAmX29mOVLWpz3b2bzzGyLmb1pZr8sdhmTFuP/wclm9u9mtnXoGHyzFOVMkpmtMbO9ZvZGhufNzO4aOkbbzKwl5xdxd/3k+APUAL8CzgYmAluBaaO2+QJw6tDvXwJeKXW5i30MRmz3C2AD8PVSl7vI58ApwHagYej+Z0td7hIcg78F7hz6vR74EJhY6rIX+DjMBVqANzI8/2XgPwED5owlC1SjHpsLgXfcfbe7HwYeAb46cgN3f9HdPxq6+zIwpchlTFrWYzDkJuBxYG8xC1cEcd7/nwFPuPu7AO5ejcfAgclmZsBJREHdX9xiJsvdNxG9r0y+CvyLR14GTjGzM3J5DQX12HwOeG/E/Z6hxzL5C6JP1EqS9RiY2eeAK4H7iliuYolzDvwBcKqZPW9mnWZ2TdFKVxxxjsE/A1OB94HXgZXuPlic4gUj17z4lHBWeCkvluaxtP0czeyLREH9J4mWqPjiHIMfADe7+0BUoaoocd7/eGA2MB/4DPCSmb3s7juTLlyRxDkGi4AtwKXA7wPPmtkL7v7bhMsWkth5kYmCemx6gM+PuD+FqMZwDDObCTwAfMnd+4pUtmKJcwxagUeGQvo04Mtm1u/u/1qUEiYrzvvvAfa5+wHggJltAmYBlRLUcY7BN4F/9Kix9h0z+zXwh8Dm4hQxCLHy4njU9DE2rwLnmNlZZjYRuAp4auQGZtYAPAH8eQXVoEbKegzc/Sx3b3L3JmA9sKJCQhpivH/g34CLzWy8mdUCfwTsKHI5kxTnGLxL9I0CMzsdOBfYXdRSlt5TwDVDvT/mAPvd/YNcdqAa9Ri4e7+Z3Qg8Q3Tle427v2lm1w89fx/wd0AdcM9QjbLfK2gmsZjHoGLFef/uvsPMnga2AYPAA+6etgtXOYp5DqwCHjSz14maAG5294qa+tTMHgbmAaeZWQ9wOzABho/BBqKeH+8AnxB9y8jtNYa6j4iISKDU9CEiEjgFtYhI4BTUIiKBU1CLiAROQS0iEjgFtYhI4BTUIiKB+3/J1INg/R/OLQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "y_ = linear_model(x_train)\n", + "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", + "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "从上面的例子可以看到,更新之后红色的线跑到了蓝色的线下面,没有特别好的拟合蓝色的真实值,所以我们需要在进行几次更新" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 19, loss: 17.798984092741378\n", + "epoch: 39, loss: 16.14508120463308\n", + "epoch: 59, loss: 15.55101918276564\n", + "epoch: 79, loss: 15.33763961353287\n", + "epoch: 99, loss: 15.26099545058815\n" + ] + } + ], + "source": [ + "for e in range(100): # 进行 100 次更新\n", + " y_ = linear_model(x_train)\n", + " loss = get_loss(y_, y_train)\n", + " \n", + " w.grad.zero_() # 注意:归零梯度\n", + " b.grad.zero_() # 注意:归零梯度\n", + " loss.backward()\n", + " \n", + " w.data = w.data - 1e-2 * w.grad.data # 更新 w\n", + " b.data = b.data - 1e-2 * b.grad.data # 更新 b \n", + " if (e + 1) % 20 == 0:\n", + " print('epoch: {}, loss: {}'.format(e, loss.item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAW+0lEQVR4nO3df5DU9X3H8df7+CGi1DgHcTT07rRTDUR+CKfFTkVUBBIyjdTMNPYSQxKDhmhpZpLRlJmaDmFSZzohMU3Uq6VOwxkn4I/aKVUSfwRn1OidQWNAwZADD205DkuUHwXu3v3je3vAsct993a/3/18d5+PmRvYvXX3s9/B1372/fll7i4AQLjqKt0AAMCpEdQAEDiCGgACR1ADQOAIagAI3MgknnT8+PHe1NSUxFMDQFXq6OjY4+4T8v0ukaBuampSe3t7Ek8NAFXJzHYU+h2lDwAIHEENAIEjqAEgcInUqPM5cuSIurq6dOjQobResuqNGTNGEydO1KhRoyrdFAAJSi2ou7q6NG7cODU1NcnM0nrZquXu6unpUVdXl84///xKNwdAglIrfRw6dEj19fWEdJmYmerr6/mGAqSorU1qapLq6qI/29rSed3UetSSCOky43oC6Wlrk5YskQ4ciG7v2BHdlqSWlmRfm8FEAIhh+fJjIZ1z4EB0f9II6piampq0Z8+eSjcDQIXs3Fnc/eUUbFAnWQtyd/X19ZXvCQFUvYaG4u4vpyCDOlcL2rFDcj9WCyolrDs7OzVp0iQtXbpUM2bM0IoVK3TppZdq6tSpuvPOOwced91112nmzJn62Mc+ptbW1jK8GwDVYOVKaezYE+8bOza6P2lBBnVStaA333xTN954o+666y7t2rVLL730kjZt2qSOjg5t3LhRkrR69Wp1dHSovb1dd999t3p6ekp7UQBVoaVFam2VGhsls+jP1tbkBxKllGd9xJVULaixsVGzZs3S17/+dW3YsEGXXHKJJOmDDz7Qtm3bNHv2bN1999169NFHJUlvv/22tm3bpvr6+tJeGEBVaGlJJ5gHCzKoGxqicke++0txxhlnSIpq1N/85jd18803n/D7Z599Vj//+c/1wgsvaOzYsZozZw7zlAFUXJClj6RrQfPnz9fq1av1wQcfSJJ27dql3bt3a9++fTr77LM1duxYvfHGG3rxxRfL84IAUIIge9S5rxbLl0fljoaGKKTL9ZVj3rx52rJliy6//HJJ0plnnqk1a9ZowYIFuvfeezV16lRddNFFmjVrVnleEABKYO5e9idtbm72wQcHbNmyRZMmTSr7a9U6ritQHcysw92b8/0uyNIHAOAYghoAAkdQA0DgCGoACBxBDQCBI6gBIHAEdR4PPPCA3nnnnYHbN910kzZv3lzy83Z2durBBx8s+r9bvHix1q1bV/LrA8imcIO6Umfe6OSgvv/++zV58uSSn3e4QQ2gtoUZ1EnscyppzZo1uuyyyzR9+nTdfPPN6u3t1eLFi3XxxRdrypQpWrVqldatW6f29na1tLRo+vTpOnjwoObMmaPcAp4zzzxTt99+u2bOnKm5c+fqpZde0pw5c3TBBRfo8ccflxQF8hVXXKEZM2ZoxowZev755yVJd9xxh5577jlNnz5dq1atUm9vr77xjW8MbLd63333SYr2Irn11ls1efJkLVy4ULt37y7pfQPIOHcv+8/MmTN9sM2bN590X0GNje5RRJ/409gY/znyvP4nP/lJP3z4sLu7f+UrX/FvfetbPnfu3IHHvPfee+7ufuWVV/rLL788cP/xtyX5+vXr3d39uuuu82uvvdYPHz7smzZt8mnTprm7+/79+/3gwYPu7r5161bPXY9nnnnGFy5cOPC89913n69YscLd3Q8dOuQzZ8707du3+8MPP+xz5871o0eP+q5du/yss87ytWvXFnxfALJPUrsXyNRYe32Y2TJJX5Zkkv7Z3b+X3EeHEtnn9KmnnlJHR4cuvfRSSdLBgwe1YMECbd++XbfddpsWLlyoefPmDfk8o0eP1oIFCyRJU6ZM0WmnnaZRo0ZpypQp6uzslCQdOXJEt956qzZt2qQRI0Zo69ateZ9rw4YNeu211wbqz/v27dO2bdu0ceNG3XDDDRoxYoTOO+88XX311cN+3wCyb8igNrOLFYX0ZZIOS3rCzP7T3bcl1qoE9jl1d33+85/Xd77znRPuX7lypZ588kn98Ic/1E9/+lOtXr36lM8zatSogdO/6+rqdNpppw38/ejRo5KkVatW6ZxzztGrr76qvr4+jRkzpmCbfvCDH2j+/Pkn3L9+/XpOGAcwIE6NepKkF939gLsflfQLSYsSbVUC+5xec801Wrdu3UC9d+/evdqxY4f6+vp0/fXXa8WKFXrllVckSePGjdP7778/7Nfat2+fzj33XNXV1enHP/6xent78z7v/Pnzdc899+jIkSOSpK1bt2r//v2aPXu2HnroIfX29urdd9/VM888M+y2AMi+OKWP1yWtNLN6SQclfUJS++AHmdkSSUskqaHUHf4T2Od08uTJ+va3v6158+apr69Po0aN0ne/+10tWrRo4KDbXG978eLFuuWWW3T66afrhRdeKPq1li5dquuvv15r167VVVddNXBgwdSpUzVy5EhNmzZNixcv1rJly9TZ2akZM2bI3TVhwgQ99thjWrRokZ5++mlNmTJFF154oa688sphv28A2Rdrm1Mz+5Kkr0r6QNJmSQfd/WuFHs82p+nhugLVoeRtTt39X9x9hrvPlrRXUnL1aQDACeLO+viwu+82swZJfyHp8mSbBQDIiXsU18P9Neojkr7q7u8N58XcndkMZRSnbAUg++KWPq5w98nuPs3dnxrOC40ZM0Y9PT2ES5m4u3p6egpO/UM2VHCnBGRIaofbTpw4UV1dXeru7k7rJavemDFjNHHixEo3A8OU2ynhwIHodm6nBKl8BzmjOqR2uC2AEzU15V/X1dgo9S9yRQ3hcFsgQAnslIAqRVADFVJoXVip68VQfQhqoEIS2CkBVYqgBiqkpUVqbY1q0mbRn62tDCTiZKnN+gBwspYWghlDo0cNgPncgaNHDdQ45nOHjx41UOOWLz8W0jkHDkT3IwwENVDjmM9dBgnXjghqoMYxn7tEudrRjh3RMdy52lEZw5qgBmoc87lLlELtiKAGahzzuYuQr8SRQu2ITZkAII7B02Ok6KvH6adLPT0nP77I3bXYlAkAijW497xsWf4Sh5R47YigBoDB8g0Q5us1S9LevYnXjih9AMBghTYLz6dMG4hT+gCAYsQdCExpegxBDaC2xFmcUmgSeX19RabHENQAakfcxSmFJpd///tRmaOvL/ozpTmMBDVQQexal7K4i1MCm1zOYCJQIYWm5bLYJEF1dVFPejCzqJdcQQwmAgFi17oKyOjGJgQ1UCHsWlcBGd3YhKAGKiSjnbtsC6z2HBdBDVRIRjt32dfSUpGZG6UgqIEKyWjnDhXAmYlABXEKOeKgRw0AgSOoAdSULC4yovQBoGYMXmSUW0EuhV2CokcNoGZkdZERQQ1gSFksF+ST1UVGBDWAU4q74VwWZHWREUEN4JSyWi7IJ6uLjGIFtZl9zcx+Y2avm9lPzGxM0g0DUJykyhNZLRfkk9VFRkMGtZl9RNJfS2p294sljZD0maQbBiC+JMsTWS0XFJLBFeSxSx8jJZ1uZiMljZX0TnJNAlCsJMsTWS0XpCnpwdYhg9rdd0n6R0k7Jb0raZ+7bxj8ODNbYmbtZtbe3d1d3lYCOKUkyxNZLRekJY3B1iFPeDGzsyU9LOkvJf2vpLWS1rn7mkL/DSe8AOlqaooCYrDGxujrPZJTrmtf6gkvcyX9zt273f2IpEck/Wn8lweQtMTLE9UykToBaQy2xgnqnZJmmdlYMzNJ10jaUr4mAChVouWJappInYA0BltjHW5rZn+vqPRxVNKvJN3k7v9X6PGUPoAqQl3llMp1SHHJh9u6+53u/lF3v9jdP3eqkAaQQacqbVTTROoEpDHYyu55QK0baku5hob8PeqsTqROQNIHQLCEHKh1Q03CrtGJ1CGNnxLUQK0bqrRRgxOpQxs/JagxIKQeBFIUZ9pCFtddlyC0jagIakgKrweBMoj7yVujpY1TCW38lKCGpPB6EBiG44N5/Hjpi1+M98lbg6WNoYS2EVWsedTFYh519tTVRf8/D2YWfdtF4PJN5s2Huc+xlGtudDFKnkeN6hdaDwJFyveVKB/mPscS2pcMghqSKFNmXtwA5pM3tpDGTwlqSAqvB4EixQlgPnkzi6DGgJB6EChSvq9Eo0ZJ9fV88lYBlpAD1SAXwMuXR2WQhoYovAnmqkBQA9Ui6Q0nUDGUPoBKWrpUGjkyKk+MHBndBgahRw1UytKl0j33HLvd23vs9o9+VJk2IUj0qIFKaW0t7n7ULIIaqJTe3uLuR80iqIFKGTGiuPtRswhqoNzi7lqXO0Ul7v2oWQwmAuU01LFWx8sNGLa2RuWOESOixzKQiEHYPQ8oJ07sxjCxex6QltB2nEdVIKgDxJFYGcZ+sUgAQR0YjsTKuJUrdXT0iZsjHR3NrnUoDUEdGI7EyrY2tejL3qpONapPpk416sveqjaxBweGj8HEwHAkVrYxlojhYjAxQyhxVkAZBwUYS0QSCOrAcCRWyso8KMAHLZJAUAeGI7FSVuZBAT5okQRq1KhtCQwKtLVx0AqKd6oaNUvIUdsaGvKP/pVQq+CgFZQbpQ/UNmoVyACCGrWNQQFkAKUPgFoFAkePGgACR1ADQOCGDGozu8jMNh3383sz+5sU2gYAUIwatbu/KWm6JJnZCEm7JD2abLMAADnFlj6ukfRbd88z8RQAkIRig/ozkn6S7xdmtsTM2s2svbu7u/SWAQAkFRHUZjZa0p9LWpvv9+7e6u7N7t48YcKEcrUPAGpeMT3qj0t6xd3/J6nGAABOVkxQ36ACZQ8AQHJiBbWZjZV0raRHkm0OAGCwWEvI3f2ApPqE2wIAyIOViQAQOIIaAAJHUGP4yngoLIDC2OYUxWtrk5Ytk3p6jt2XOxRWYstQoMzoUaM4uVO7jw/pnBIOhQVQGEGN4uQ7tft4O3em1xagRhDUKM5QQVzCobAA8iOocUycwcFTBTGHwgKJIKgRydWed+yQ3I8NDg4O63yndktSfT2HwgIJIagRyVd7zjc4mO/U7jVrpD17CGkgIebuZX/S5uZmb29vL/vzIkF1dVFPejAzqa8v/fYANcbMOty9Od/v6FEjUqj2zOAgUHGZCWoWwSUsX+2ZwUEgCJkI6rjjXChBvtpzgcFBPjSBdGWiRt3UFIXzYI2NUmdn2V4GMeQ+NI8fdxw7lgkfQKkyX6MutMYihEVwtda7jDs5BED5ZCKoQx3nqsWSTMgfmkC1ykRQhzrOVYu9y1A/NIFqlomgLmKcK1W12LsM9UMTqGaZCGopCuXOzmjtRWdn5UNaSqF3GWABPNQPTaCaZSaoQ5RY77KtTRo/XvrsZ4MsgIf4oQlUM4K6BIn0LtmYH8AgBHWJyt67ZGN+lEmAlTMME2cmhoaN+VEGgxcmcaRlttGjDg0b86MManHqaDUjqEPDxvwog1qcOlrNCOrQsDE/yoCFSdWFoA4R899QIhYmVReCGqhCLEyqLgR1qZgDhUDxxax6MD2vFMyBApACetSlYA4UgBQQ1KVgDhSAFBDUpWAOFIAUENSlYA4UgBTECmoz+5CZrTOzN8xsi5ldnnTDMoE5UABSEHfWx/clPeHunzaz0ZLyrHGuUS0tBDOARA0Z1Gb2B5JmS1osSe5+WNLhZJsFAMiJU/q4QFK3pH81s1+Z2f1mdkbC7QIA9IsT1CMlzZB0j7tfImm/pDsGP8jMlphZu5m1d3d3l7mZQHJYXIrQxQnqLkld7v7L/tvrFAX3Cdy91d2b3b15woQJ5WwjkJjc4tIAj6YEBgwZ1O7+35LeNrOL+u+6RtLmRFsFpITFpciCuLM+bpPU1j/jY7ukLyTXJCA9LC5FFsQKanffJKk52aYA6WtoiMod+e4HQpGdlYmM+CABLC5FFmQjqBnxQUJYXIosMHcv+5M2Nzd7e3t7+Z6wqSn/99PGxmhHdADIODPrcPe8JeZs9KgZ8QFQw8IJ6lPVoNlOFEANCyOoh6pBM+IDoIaFEdRDrTpgxAdADQtjMLGuLupJD2YWHaEMAFUu/MFEatAAUFAYQU0NGgAKCiOoqUEDQEFhBLWkNrWoSZ2qU5+a1Kk2EdIAIMXfPS9Rudl5uYkfudl5Ep1qAAiiR82ewABQWBBBzQpxACgsiKBmdh4AFBZEUDM7DwAKCyKomZ0HAIUFMetDikKZYAaAkwXRowYAFEZQA0DgCGoACBxBjWHjYHggHcEMJiJbWPYPpIceNYaFZf9AeghqDAvL/oH0ENQYFpb9A+khqDEsLPsH0kNQY1hY9g+kh1kfGDaW/QPpoEcNAIEjqAEgcAQ1AASOoA4YS7QBSAwmBosl2gBy6FEHiiXaAHJi9ajNrFPS+5J6JR119+YkGwWWaAM4ppjSx1XuviexluAEDQ1RuSPf/QBqC6WPQLFEG0BO3KB2SRvMrMPMluR7gJktMbN2M2vv7u4uXwtrFEu0AeSYuw/9ILPz3P0dM/uwpJ9Jus3dNxZ6fHNzs7e3t5exmQBQ3cyso9D4X6wetbu/0//nbkmPSrqsfM0DAJzKkEFtZmeY2bjc3yXNk/R60g0DAETizPo4R9KjZpZ7/IPu/kSirQIADBgyqN19u6RpKbQFAJAH0/MAIHCxZn0U/aRm3ZLyLNeoCeMl1frCIK4B10DiGkjFXYNGd5+Q7xeJBHUtM7P2Wl9izzXgGkhcA6l814DSBwAEjqAGgMAR1OXXWukGBIBrwDWQuAZSma4BNWoACBw9agAIHEENAIEjqIfJzBaY2Ztm9paZ3ZHn9y1m9lr/z/NmVnWrO4e6Bsc97lIz6zWzT6fZvqTFef9mNsfMNpnZb8zsF2m3MWkx/j84y8z+w8xe7b8GX6hEO5NkZqvNbLeZ5d0DySJ391+j18xsRtEv4u78FPkjaYSk30q6QNJoSa9KmjzoMX8q6ez+v39c0i8r3e60r8Fxj3ta0npJn650u1P+N/AhSZslNfTf/nCl212Ba/C3ku7q//sESXslja5028t8HWZLmiHp9QK//4Sk/5JkkmYNJwvoUQ/PZZLecvft7n5Y0kOSPnX8A9z9eXd/r//mi5ImptzGpA15DfrdJulhSbvTbFwK4rz/v5L0iLvvlAa2Ca4mca6BSxpn0a5uZyoK6qPpNjNZHu3Nv/cUD/mUpH/zyIuSPmRm5xbzGgT18HxE0tvH3e7qv6+QLyn6RK0mQ14DM/uIpEWS7k2xXWmJ82/gQklnm9mz/acj3Zha69IR5xr8k6RJkt6R9GtJy9y9L53mBaPYvDhJMYfb4hjLc1/eeY5mdpWioP6zRFuUvjjX4HuSbnf33v5tcqtJnPc/UtJMSddIOl3SC2b2ortvTbpxKYlzDeZL2iTpakl/JOlnZvacu/8+4baFJHZeFEJQD0+XpD887vZERT2GE5jZVEn3S/q4u/ek1La0xLkGzZIe6g/p8ZI+YWZH3f2xVFqYrDjvv0vSHnffL2m/mW1UtGVwtQR1nGvwBUn/4FGx9i0z+52kj0p6KZ0mBiFWXpwKpY/heVnSH5vZ+WY2WtJnJD1+/APMrEHSI5I+V0U9qOMNeQ3c/Xx3b3L3JknrJC2tkpCWYrx/Sf8u6QozG2lmYyX9iaQtKbczSXGuwU5F3yhkZudIukjS9lRbWXmPS7qxf/bHLEn73P3dYp6AHvUwuPtRM7tV0pOKRr5Xu/tvzOyW/t/fK+nvJNVL+lF/j/KoV9FOYjGvQdWK8/7dfYuZPSHpNUl9ku5396o5xi7mv4EVkh4ws18rKgHc7u5VtfWpmf1E0hxJ482sS9KdkkZJA9dgvaKZH29JOqDoW0Zxr9E/fQQAEChKHwAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABO7/AXHnjGNFpFY8AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "y_ = linear_model(x_train)\n", + "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", + "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "经过 100 次更新,我们发现红色的预测结果已经比较好的拟合了蓝色的真实值。\n", + "\n", + "现在你已经学会了你的第一个机器学习模型了,再接再厉,完成下面的小练习。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 多项式回归模型" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "下面更进一步尝试一下多项式回归,下面是关于 x 的多项式:\n", + "\n", + "$$\n", + "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 \n", + "$$\n", + "\n", + "这样就能够拟合更加复杂的模型,这就是多项式模型,这里使用了 $x$ 的更高次,同理还有多元回归模型,形式也是一样的,只是出了使用 $x$,还是更多的变量,比如 $y$、$z$ 等等,同时他们的 $loss$ 函数和简单的线性回归模型是一致的。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "首先我们可以先定义一个需要拟合的目标函数,这个函数是个三次的多项式" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n" + ] + } + ], + "source": [ + "# 定义一个多变量函数\n", + "\n", + "w_target = np.array([0.5, 3, 2.4]) # 定义参数\n", + "b_target = np.array([0.9]) # 定义参数\n", + "\n", + "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n", + " b_target[0], w_target[0], w_target[1], w_target[2]) # 打印出函数的式子\n", + "\n", + "print(f_des)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们可以先画出这个多项式的图像" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlSklEQVR4nO3deXxU9b3/8dcnC4RAIIBhDRCQNaAIBBC1aIuite7WpValdYGqbbWbdemv1rbe2ta2t96rt8UFUajCddeqBa1KXZBdEcIma8KSsIYl68zn90cGbrRhSTLJmZm8n48Hj5k558z5fs4A73zznXO+x9wdERFJTElBFyAiIo1HIS8iksAU8iIiCUwhLyKSwBTyIiIJTCEvIpLAjjnkzexxMysys09rLOtgZrPNbHXksX2NdXea2RozW2lmZ0e7cBERObq69OSfAM75wrI7gLfcvR/wVuQ1ZpYLXAkMjrznYTNLbnC1IiJSJynHuqG7zzGznC8svhA4I/J8KvAO8NPI8mfcvRxYZ2ZrgFHAh0dq47jjjvOcnC82ISIiR7Jw4cLt7p5V27pjDvnD6OzuWwDcfYuZdYos7w7MrbFdQWTZEeXk5LBgwYIGliQi0ryY2YbDrWusL16tlmW1zp9gZhPNbIGZLSguLm6kckREmqeGhvw2M+sKEHksiiwvAHrU2C4b2FzbDtx9srvnuXteVlatv22IiEg9NTTkXwYmRJ5PAF6qsfxKM2tpZr2BfsC8BrYlIiJ1dMxj8mb2NNVfsh5nZgXAPcD9wEwzux7YCFwG4O7LzGwmsByoAm5x91B9CqysrKSgoICysrL6vF1qkZaWRnZ2NqmpqUGXIiKNzGJpquG8vDz/4hev69atIyMjg44dO2JW21C/1IW7s2PHDvbu3Uvv3r2DLkdEosDMFrp7Xm3rYv6K17KyMgV8FJkZHTt21G9GIs1EzIc8oICPMn2eIs1HXIS8iEgim/L+OmYv39Yo+1bIN4GcnBy2b98edBkiEoP2lFbyuzdWMnv51kbZv0K+DtydcDgcdBkxU4eINNyzCwsorQxx7ZicRtm/Qv4o1q9fz6BBg7j55psZPnw4mzZt4ve//z0jR47kxBNP5J577jm07UUXXcSIESMYPHgwkydPPuq+33jjDYYPH87QoUMZN24cAL/4xS944IEHDm0zZMgQ1q9f/291/OpXv+L2228/tN0TTzzB9773PQCmTZvGqFGjOOmkk5g0aRKhUL3OXhWRRhYOO099uJ4RvdozpHu7RmmjoXPXNKl7X1nG8s0lUd1nbre23HP+4CNus3LlSqZMmcLDDz/MrFmzWL16NfPmzcPdueCCC5gzZw5jx47l8ccfp0OHDpSWljJy5EguvfRSOnbsWOs+i4uLufHGG5kzZw69e/dm586dR621Zh3FxcWMGTOG3/3udwDMmDGDu+++m/z8fGbMmMH7779PamoqN998M9OnT+faa6+t+4cjIo3qX2u2s37HAX5wVv9GayOuQj4ovXr14uSTTwZg1qxZzJo1i2HDhgGwb98+Vq9ezdixY3nwwQd54YUXANi0aROrV68+bMjPnTuXsWPHHjpXvUOHDnWqIysriz59+jB37lz69evHypUrOfXUU3nooYdYuHAhI0eOBKC0tJROnTodabciEpAnP1jPcW1a8tUhXRutjbgK+aP1uBtL69atDz13d+68804mTZr0uW3eeecd3nzzTT788EPS09M544wzjnguurvXeipjSkrK58bba+6jZh0AV1xxBTNnzmTgwIFcfPHFmBnuzoQJE/jNb35T5+MUkaazcccB/rmyiO99uS8tUhpv5Fxj8nV09tln8/jjj7Nv3z4ACgsLKSoqYs+ePbRv35709HRWrFjB3Llzj7ifMWPG8O6777Ju3TqAQ8M1OTk5LFq0CIBFixYdWl+bSy65hBdffJGnn36aK664AoBx48bx7LPPUlRUdGi/GzYcdhZSEQnItI82kGTGVaN7NWo7cdWTjwXjx48nPz+fMWPGANCmTRumTZvGOeecw1/+8hdOPPFEBgwYcGhY5XCysrKYPHkyl1xyCeFwmE6dOjF79mwuvfRSnnzySU466SRGjhxJ//6HH6tr3749ubm5LF++nFGjRgGQm5vLr3/9a8aPH084HCY1NZWHHnqIXr0a9x+SiBy70ooQM+Zv4uzBnenSLq1R24r5uWvy8/MZNGhQQBUlLn2uIsGZOX8Ttz/3Cc9MPJmT+9T+vV1dxPXcNSIiicTdeeKD9QzonMHo3kc/4aKhFPIiIk1o0cZdLN9SwjVjejXJPFJxEfKxNKSUCPR5igRn6gcbyGiZwsXDjnrb66iI+ZBPS0tjx44dCqYoOTiffFpa437ZIyL/buueMl5buoWv52XTumXTnPcS82fXZGdnU1BQgG7yHT0H7wwlIk1r6ofrCbvz7VOa7oY9MR/yqampuoORiMS9/eVVTJ+7gbMHd6Fnx/Qmazfmh2tERBLBswsLKCmr4oYv9WnSdhXyIiKNLBR2HntvHcN6ZjKiV/smbVshLyLSyGYv38bGnQe4sYl78RClkDezH5jZMjP71MyeNrM0M+tgZrPNbHXksWl/fImIxIjH3ltLdvtWjM/t3ORtNzjkzaw78H0gz92HAMnAlcAdwFvu3g94K/JaRKRZWbJpN/PX7+K6U3uTktz0gyfRajEFaGVmKUA6sBm4EJgaWT8VuChKbYmIxI1H/7WWjLQULh/ZI5D2Gxzy7l4IPABsBLYAe9x9FtDZ3bdEttkC1HrnCjObaGYLzGyBzoUXkURSsOsAr3+6latG9aRNE1389EXRGK5pT3WvvTfQDWhtZlcf6/vdfbK757l7XlZWVkPLERGJGU+8vx4DJpySE1gN0RiuORNY5+7F7l4JPA+cAmwzs64AkceiKLQlIhIX9pRW8sz8TZx7Qle6ZbYKrI5ohPxG4GQzS7fqKdXGAfnAy8CEyDYTgJei0JaISFyYNncD+8qr+M7pxwdaR4MHidz9IzN7FlgEVAGLgclAG2CmmV1P9Q+CyxralohIPCitCPHYe+v48oAscru1DbSWqHwT4O73APd8YXE51b16EZFm5Zn5G9m5v4Jbvtw36FJ0xauISDRVVIWZPGcto3p3IC+n8e/8dDQKeRGRKHpxcSFb9pTFRC8eFPIiIlETCjv/8+5nDOnelrH9jgu6HEAhLyISNa9/uoV12/dzyxl9m+T+rcdCIS8iEgXuzkNvf8bxWa05e3CXoMs5RCEvIhIF76wsJn9LCd85/XiSkmKjFw8KeRGRBqvuxa+he2YrLhrWPehyPkchLyLSQB98toMFG3YxcWwfUgOYTvhIYqsaEZE44+78cfYqurRN44qAphM+EoW8iEgDzFm9nYUbdnHLV/qSlpocdDn/RiEvIlJPB3vx3TNbcXledtDl1EohLyJST2+vLOLjTbv53lf60jIl9nrxoJAXEamXg734Hh1acemI2OzFg0JeRKReZi/fxqeFJXz/K/1i7oyammK3MhGRGBUOO396czU5HdO5OMbOi/8ihbyISB39Y9lW8reUcOuZ/UiJ4V48KORFROqkuhe/ij5ZrblgaGz34kEhLyJSJy8uKWTVtn3cdmZ/kmNojprDUciLiByjssoQf5i1ihO6t+O8E7oGXc4xUciLiByjpz7cQOHuUu746sCYmmnySBTyIiLHYM+BSv777TWM7Z/FqX1j465PxyIqIW9mmWb2rJmtMLN8MxtjZh3MbLaZrY48to9GWyIiQXj43TWUlFVyxzkDgy6lTqLVk/8z8Ia7DwSGAvnAHcBb7t4PeCvyWkQk7mzeXcqU99dz8bDu5HZrG3Q5ddLgkDeztsBY4DEAd69w993AhcDUyGZTgYsa2paISBD+OHsVAD8aPyDgSuouGj35PkAxMMXMFpvZo2bWGujs7lsAIo+danuzmU00swVmtqC4uDgK5YiIRM+KrSU8t6iAb52SQ/fMVkGXU2fRCPkUYDjwP+4+DNhPHYZm3H2yu+e5e15WVlYUyhERiZ7fvr6CjJYp3HzG8UGXUi/RCPkCoMDdP4q8fpbq0N9mZl0BIo9FUWhLRKTJvL9mO2+vLOaWL/clM71F0OXUS4ND3t23ApvM7OBg1ThgOfAyMCGybALwUkPbEhFpKlWhMPe+soyeHdKZcEpO0OXUW0qU9vM9YLqZtQDWAt+m+gfITDO7HtgIXBaltkREGt20uRtYtW0ff71mREze1u9YRSXk3X0JkFfLqnHR2L+ISFPaub+CP85exWl9j2N8buegy2kQXfEqIvIFD8xayf6KEPecn4tZfExfcDgKeRGRGpZt3sPT8zZy7Zhe9OucEXQ5DaaQFxGJcHfufXk57dNbcNuZ/YMuJyoU8iIiEa9+soV563fy4/EDaNcqNehyokIhLyIClFaE+M1r+Qzu1pYrRvYIupyoUciLiAAP/nM1m/eUcc/5g+Pijk/HSiEvIs3eiq0lPDJnLZeNyGZU7w5BlxNVCnkRadbCYeeu55fStlUqd507KOhyok4hLyLN2t/mbWTRxt3cfe4g2reOz/lpjkQhLyLNVlFJGb99YwWnHN+RS4Z3D7qcRqGQF5Fm65evLqe8KsyvLxoS91e2Ho5CXkSapbdXFvHqJ1v47pf70ierTdDlNBqFvIg0O6UVIf7fi59yfFZrJp3eJ+hyGlW0phoWEYkbv//HSgp2lTJj4sm0TInfaYSPhXryItKsfLR2B1M+WMe1Y3oxuk/HoMtpdAp5EWk29pdX8ZNnP6FH+3Tu+OrAoMtpEhquEZFm4/7XV7Bp1wFmTBxDeovmEX/qyYtIs/D+mu08NXcD153aO+GmLjgShbyIJLy9ZZXc/uwn9MlqzU/OHhB0OU2qefy+IiLN2n1/z2fLnlKevemUuL4pd32oJy8iCe3tFUU8M38TE8cez/Ce7YMup8lFLeTNLNnMFpvZq5HXHcxstpmtjjw2v09XRAJVVFLGj//3YwZ2yeAHZ/ULupxARLMnfyuQX+P1HcBb7t4PeCvyWkSkSYTDzg9nfsz+iir++6phCX/R0+FEJeTNLBv4GvBojcUXAlMjz6cCF0WjLRGRY/HXOWt5b812fnH+YPp2ygi6nMBEqyf/n8DtQLjGss7uvgUg8tiptjea2UQzW2BmC4qLi6NUjog0Z4s37uIPs1bytRO6JtT9WuujwSFvZucBRe6+sD7vd/fJ7p7n7nlZWVkNLUdEmrmSskq+/8xiOrdN4z8uOSFhpxA+VtE4hfJU4AIzOxdIA9qa2TRgm5l1dfctZtYVKIpCWyIih+Xu3P3Cp2zeXcbMSSfTrlVq0CUFrsE9eXe/092z3T0HuBL4p7tfDbwMTIhsNgF4qaFtiYgcyYz5m3jl48384Mx+jOjVfK5qPZLGPE/+fuAsM1sNnBV5LSLSKD7etJufv7yM0/oex01n9A26nJgR1Ste3f0d4J3I8x3AuGjuX0SkNtv3lXPTtIVktWnJg98YRnJS8x6Hr0nTGohIXKsKhfne3xazY38Fz910Ch1atwi6pJiikBeRuPa7f6zkw7U7eOCyoQzp3i7ocmKO5q4Rkbj16iebmTxnLdeO6cXXR2QHXU5MUsiLSFxauXUvtz/7CXm92vOzr+UGXU7MUsiLSNwp3lvO9VPn07plCg9/czgtUhRlh6NPRkTiSmlFiBueXMCOfRU8NiGPTm3Tgi4ppumLVxGJG+Gw84MZS/ikYDd/vXoEJ2ZnBl1SzFNPXkTixm/fWMEby7bys6/lMn5wl6DLiQsKeRGJC9M/2sBfI2fSXHdqTtDlxA2FvIjEvHdWFvHzl5bx5QFZ/Py83GY/s2RdKORFJKYt3LCTm6Yton/nDP7rquGkJCu26kKflojErGWb9/CtKfPp3LYlT143ijYtda5IXSnkRSQmrS3ex7WPzSOjZQrTbhhNVkbLoEuKSwp5EYk5hbtLufrRjwB46obRZLdPD7ii+KXffUQkphTvLeeaRz9ib3kVT994MsdntQm6pLimnryIxIziveVc/ehHbN5TypRvjdSsklGgnryIxIRtJWVc9chcCneX8tiEkeTl6PZ90aCQF5HAbd5dylWPzKV4bzlTvz2K0X06Bl1SwlDIi0igNu08wDcemcueA5U8ef1oRvRqH3RJCUUhLyKBWb99P1c9Mpf9FSGm3zhaE441ggZ/8WpmPczsbTPLN7NlZnZrZHkHM5ttZqsjj/rxLCKHLC3Yw9f/8iFlVWGevvFkBXwjicbZNVXAj9x9EHAycIuZ5QJ3AG+5ez/grchrERHeXlHEFZM/pGVKEjMnjSG3W9ugS0pYDQ55d9/i7osiz/cC+UB34EJgamSzqcBFDW1LROLfM/M2csOTC+iT1ZoXbjmFvp10HnxjiuqYvJnlAMOAj4DO7r4Fqn8QmFmnaLYlIvHF3fnT7FU8+M81nN4/i4e+OVxz0TSBqH3CZtYGeA64zd1LjnUqUDObCEwE6NmzZ7TKEZEYUlYZ4q4XlvL8okKuyOvBry8eQqpmk2wSUfmUzSyV6oCf7u7PRxZvM7OukfVdgaLa3uvuk909z93zsrKyolGOiMSQzbtLufyvH/L8okJ+eFZ/7r/0BAV8E2pwT96qu+yPAfnu/scaq14GJgD3Rx5famhbIhJf5q7dwS3TF1FeFWbyNSN0y74ARGO45lTgGmCpmS2JLLuL6nCfaWbXAxuBy6LQlojEAXdnyvvrue+1fHI6pvPXa/L0BWtAGhzy7v4ecLgB+HEN3b+IxJd95VX87IWlvLhkM+NzO/OHy4eSkZYadFnNlr7aFpGoWbxxF7fNWMKmnQf40Vn9ueXLfUlK0v1Yg6SQF5EGC4Wd/3lnDX96czVd2qYxY9IYRmoWyZigkBeRBtm8u5TbZixh3rqdnD+0G7++aAjtWml4JlYo5EWkXsJhZ8aCTfzHa/mEw84fLhvKJcO7c6zXyEjTUMiLSJ19VryPO59fyrx1Ozm5Twd+e+mJ9OrYOuiypBYKeRE5ZhVVYSbP+YwH/7mGtJQkfnvpCVye10O99ximkBeRY/LBmu3c+8pyVm7by9dO6Mo9F+TSKSMt6LLkKBTyInJEa4v38R+v5fNmfhHZ7VvxyLV5nJXbOeiy5Bgp5EWkVrsPVPDnt1bz1IcbSEtN5vZzBnDdqb1JS00OujSpA4W8iHzO3rJKpn6wnkf+tY69ZZVcMbInPzyrP1kZLYMuTepBIS8iAJSUVfLE++t57L117CmtZNzATvzknAEM7KK7NsUzhbxIM7dzfwVPfbiBx95bS0lZFWcO6syt4/pxQna7oEuTKFDIizRTK7aWMOW99by4pJDyqjBn5VaH+5DuCvdEopAXaUaqQmH+uaKIKe+v58O1O0hLTeKS4dl8+9Qc+nfOCLo8aQQKeZFmYNW2vTy3sIDnFxdSvLecbu3S+Ok5A/nGqB5kprcIujxpRAp5kQRVtLeM15du5blFBXxSsIeUJOOMAZ34+ohszhzUiRTdgq9ZUMiLJJCNOw7wj2Vb+ceyrSzcuAt3yO3alp+fl8sFJ3XjuDY6DbK5UciLxLGyyhCLNuzivTXbeXtlMflbSgAY1LUtt43rzzlDujCgi8bamzOFvEgcKa8KsWxzCfPW7eT9NduZt24n5VVhkpOM4T0zufvcQZw9uAs9O6YHXarECIW8SIwKhZ0NO/bz6eYSFm/cxeKNu1m+uYSKUBiAAZ0z+OboXpzWryOjenekTUv9d5Z/p38VIgGrqApTsOsAG3Ye4LOifazYupeVW/eyumgvZZXVgd4qNZkTs9tx3Wm9GdYzk2E9MzUDpByTRg95MzsH+DOQDDzq7vc3dpsiscDdKa0MsetAJdv3lrOtpIxte8spKilj654yCneXsmHHAbbsKSXs//e+rIyWDOySwdWjezGgSwaDurZlYJcMnQ0j9dKoIW9mycBDwFlAATDfzF529+WN2a40nbLKEMV7y9m+r5zdByopKaukpLSSkrIqSkor2VdeRWlliLLKEGWVYUorQpRXhagKO1UhpyocpirkhLw65dw/v//kJMMMks1ITjKSzEhNrn6ekpxESuQxNclITU4iNaX6eUpy5HVkm4PLk5OSSEk2UpIi+0gykpIMMyPJwKh+dCDsjnt1WDvVPe7KkFMZCkeehymtDLG/PMSBiioOVFQ/lpRWsetABbtLK6moCv/bZ5Zk1UHeLbMVI3Pa07NjNr06pNOrYzq9j2tNR50BI1HU2D35UcAad18LYGbPABcCCvk44O5s2VPGxp0HKNhVSsGu/3ssKimneG85e8urDvv+tNQkWrdIoVWLZNJSk2kV+ZPeIiUStAdDujpwD95b6OBdhtydsEPInXDYCYWdsPvnfkBUVIXZXxGiKlQdulUhp6LG88pQmKpw9WNlyA9ba12ZQYvkJFq1SKZ1ixTSWyST3jKF9NRkco5LZ1h6Ju3SU8ls1YL26al0bNOSLm3T6Ny2JR3btCQ5SXdSkqbR2CHfHdhU43UBMLqR25R6KCopY2nhHlZt28fqor18VrSPNUX72F8ROrSNGXRpm0b3zFYM6taWsW1akpXRkqw2LTkuowXt01vQtlUqbdNSadsqhZYpsTXv+MEfGlXhMKHw//2wOLj84GPYnaRIzx6DJKv+AZSakkSLyG8HCmmJF40d8rX9T/hcd8rMJgITAXr27NnI5QjAvvIqFm3YxcebdvNJ4R6WFuxha0nZofWd27akX6cMLsvrQd9Obcjp2Jrs9q3ompkWc8FdF2ZGskFyUvweg0hdNXbIFwA9arzOBjbX3MDdJwOTAfLy8qL3+7Qcsq+8igXrdzJ37U7mrt3B0sI9hCLf9PXJas3JfTpwQnYmJ2a3Y0CXDNqmpQZcsYhES2OH/Hygn5n1BgqBK4GrGrlNAdZv38+b+dt4K7+I+et3UhV2UpONodmZ3HT68Yzu04GhPTIV6CIJrlFD3t2rzOy7wD+oPoXycXdf1phtNlfuzqeFJbz6yWZm529jbfF+APp1asP1X+rNl/pmMbxXJuktdGmESHPS6P/j3f014LXGbqe52rBjPy8u3sxLHxeytng/qcnG6N4duebkXowb2FmXt4s0c+rWxaHSihAvf1zI0/M2sWTTbgBG9+7AjV/qw1eHdNH84CJyiEI+jnxWvI9pczfw7MIC9pZVMaBzBnd+dSDnD+1Gt8xWQZcnIjFIIR/j3J13Vhbz6HtreX/NDlKTjXNP6MrVJ/cir1f7QxcOiYjURiEfo0Jh57WlW3j4nc/I31JCt3Zp/OTsAVye14OsDF32LiLHRiEfYyqqwrywuIC/vLuWddv3c3xWax64bCgXntSNVE1QJSJ1pJCPEeGw8+rSLTzwj5Vs3HmAE7q34y9XD2d8bheSdAm9iNSTQj4GfLBmO795fQVLC/cwsEsGU741kjMGZGm8XUQaTCEfoFXb9nLf3/N5d1Ux3dql8YfLhnLRsO6a/EpEokYhH4ADFVX8+a3VPPavdaS3SOaucwdy7Zgc0lI1cZaIRJdCvonNWraVe19ZTuHuUi4bkc2d5w6iQ2tdvCQijUMh30QKd5dyz0vLeDN/G/07t+F/vzOGkTkdgi5LRBKcQr6RuTvPLizg3leWEwo7d351INed1lunQ4pIk1DIN6Lt+8q56/mlzFq+jVE5HfjD5UPp0UETholI01HIN5LZy7dx5/OfUFJaxd3nDuK603rrrBkRaXIK+Sgrqwxx7yvLeHreJnK7tmX6DScxoEtG0GWJSDOlkI+ijTsOcNP0hSzbXMJNZxzPD87sT4sUjb2LSHAU8lEye/k2fjhzCUlmPP6tPL4ysHPQJYmIKOQbqioU5oFZq/jLu59xQvd2PPzN4fpyVURihkK+AfYcqOSm6Qv54LMdXDW6Jz8/L1dXrYpITFHI19OGHfv59hPz2bTzAL//+olcltcj6JJERP6NQr4e5q/fycQnF+DAtOtHM7pPx6BLEhGpVYNO/TCz35vZCjP7xMxeMLPMGuvuNLM1ZrbSzM5ucKUx4sXFhXzzkY9on96CF24+VQEvIjGtoef3zQaGuPuJwCrgTgAzywWuBAYD5wAPm1lcD1a7O39+czW3zVjC8F6ZPH/zKfQ+rnXQZYmIHFGDQt7dZ7l7VeTlXCA78vxC4Bl3L3f3dcAaYFRD2gpSOOzc+8py/vTmKi4Z3p0nrxtNZrpmjhSR2BfNK3WuA16PPO8ObKqxriCy7N+Y2UQzW2BmC4qLi6NYTnRUhcLc/twnPPHBeq4/rTd/uGyoLnASkbhx1C9ezexNoEstq+5295ci29wNVAHTD76tlu29tv27+2RgMkBeXl6t2wSlvCrEbc8s4fVPt3Lbmf24dVw/3ZJPROLKUUPe3c880nozmwCcB4xz94MhXQDUPKcwG9hc3yKDUFoRYtK0hcxZVcz/Oy+X60/rHXRJIiJ11tCza84Bfgpc4O4Haqx6GbjSzFqaWW+gHzCvIW01pQMVVUyYMo/3Vhfzu0tPVMCLSNxq6Hny/w20BGZHhjHmuvt33H2Zmc0EllM9jHOLu4ca2FaTKKsMccPUBSxYv5M/XzmM84d2C7okEZF6a1DIu3vfI6y7D7ivIftvauVVISY9tZAP1+7gT5efpIAXkbin00QiKkNhbpm+mHdXFXP/JSdw0bBaTwYSEYkrCnmqT5O87ZklvJm/jV9dOJgrRvYMuiQRkaho9iEfDju3P/sJf1+6hZ99bRDXjMkJuiQRkahp9iF//xsreH5xIT86qz83fKlP0OWIiERVsw75Ke+vY/KctVw7phff/cphv0MWEYlbzTbkX1+6hV++upzxuZ255/zBupJVRBJSswz5Bet3cuuMJQzrkcmD3xhGcpICXkQSU7ML+TVF+7h+6gKyM1vx6ISRul2fiCS0ZhXyO/aV860p80hNNp749ig6tNZ0wSKS2JrN7f8qQ2Fumr6I4r3lzJw0hp4d04MuSUSk0TWbkP/lK8uZt24nf77yJIb2yAy6HBGRJtEshmv+9tFGnpq7gUmn9+HCkzRdgYg0Hwkf8vPW7eTnL33K6f2zuP3sgUGXIyLSpBI65At3l3LTtIX06JCuUyVFpFlK2JAvqwwx6akFlFeFeeTaEbRrlRp0SSIiTS5hv3i995XlfFpYwqPX5tG3U0bQ5YiIBCIhe/IvLSnk6XkbmXR6H87M7Rx0OSIigUm4kF9bvI+7nl/KiF7t+fH4AUGXIyISqIQK+bLKELf8bTGpKUn81zeGkZqcUIcnIlJnCTUm/6tXl5O/pYTHv5VHt8xWQZcjIhK4hOnqvvLxZqZ/tJFJY/vwlYEahxcRgSiFvJn92MzczI6rsexOM1tjZivN7OxotHM467bv587nlzK8ZyY/Plvj8CIiBzV4uMbMegBnARtrLMsFrgQGA92AN82sv7uHGtpebVKSjGE9M7n/0hM1Di8iUkM0EvFPwO2A11h2IfCMu5e7+zpgDTAqCm3VqkeHdJ66fjTdNQ4vIvI5DQp5M7sAKHT3j7+wqjuwqcbrgsiy2vYx0cwWmNmC4uLihpQjIiJfcNThGjN7E+hSy6q7gbuA8bW9rZZlXssy3H0yMBkgLy+v1m1ERKR+jhry7n5mbcvN7ASgN/Bx5CbY2cAiMxtFdc+9R43Ns4HNDa5WRETqpN7DNe6+1N07uXuOu+dQHezD3X0r8DJwpZm1NLPeQD9gXlQqFhGRY9YoF0O5+zIzmwksB6qAWxrrzBoRETm8qIV8pDdf8/V9wH3R2r+IiNSdTioXEUlgCnkRkQRm7rFz1qKZFQMbGrCL44DtUSonSIlyHKBjiUWJchygYzmol7tn1bYipkK+ocxsgbvnBV1HQyXKcYCOJRYlynGAjuVYaLhGRCSBKeRFRBJYooX85KALiJJEOQ7QscSiRDkO0LEcVUKNyYuIyOclWk9eRERqSKiQN7NfmdknZrbEzGaZWbega6ovM/u9ma2IHM8LZpYZdE31ZWaXmdkyMwubWdydCWFm50TucLbGzO4Iup76MrPHzazIzD4NupaGMrMeZva2meVH/m3dGnRN9WFmaWY2z8w+jhzHvVFvI5GGa8ysrbuXRJ5/H8h19+8EXFa9mNl44J/uXmVmvwVw958GXFa9mNkgIAz8Ffixuy8IuKRjZmbJwCqq735WAMwHvuHuywMtrB7MbCywD3jS3YcEXU9DmFlXoKu7LzKzDGAhcFG8/b1Y9RS+rd19n5mlAu8Bt7r73Gi1kVA9+YMBH9Gaw8xhHw/cfZa7V0VezqV6uua45O757r4y6DrqaRSwxt3XunsF8AzVdz6LO+4+B9gZdB3R4O5b3H1R5PleIJ/D3Jgolnm1fZGXqZE/Uc2thAp5ADO7z8w2Ad8Efh50PVFyHfB60EU0U8d8lzMJhpnlAMOAjwIupV7MLNnMlgBFwGx3j+pxxF3Im9mbZvZpLX8uBHD3u929BzAd+G6w1R7Z0Y4lss3dVE/XPD24So/uWI4lTh3zXc6k6ZlZG+A54LYv/CYfN9w95O4nUf3b+igzi+pQWqPMJ9+YDnenqlr8Dfg7cE8jltMgRzsWM5sAnAeM8xj/8qQOfy/xRnc5i1GRMezngOnu/nzQ9TSUu+82s3eAc4CofTkedz35IzGzfjVeXgCsCKqWhjKzc4CfAhe4+4Gg62nG5gP9zKy3mbUArqT6zmcSoMgXlo8B+e7+x6DrqS8zyzp45pyZtQLOJMq5lWhn1zwHDKD6TI4NwHfcvTDYqurHzNYALYEdkUVz4/hMoYuB/wKygN3AEnc/O9Ci6sDMzgX+E0gGHo/cECfumNnTwBlUz3a4DbjH3R8LtKh6MrPTgH8BS6n+/w5wl7u/FlxVdWdmJwJTqf63lQTMdPdfRrWNRAp5ERH5vIQarhERkc9TyIuIJDCFvIhIAlPIi4gkMIW8iEgCU8iLiCQwhbyISAJTyIuIJLD/D4NyVqoKGt6uAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# 画出这个函数的曲线\n", + "x_sample = np.arange(-3, 3.1, 0.1)\n", + "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n", + "\n", + "plt.plot(x_sample, y_sample, label='real curve')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着构建数据集,需要 x 和 y,同时是一个三次多项式,所以取 $x,\\ x^2, x^3$" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# 构建数据 x 和 y\n", + "# x 是一个如下矩阵 [x, x^2, x^3]\n", + "# y 是函数的结果 [y]\n", + "\n", + "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n", + "x_train = torch.from_numpy(x_train).float() # 转换成 float tensor\n", + "\n", + "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # 转化成 float tensor " + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([61, 3])\n" + ] + } + ], + "source": [ + "print(x_train.size())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着我们可以定义需要优化的参数,就是前面这个函数里面的 $w_i$" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# 定义参数\n", + "w = torch.randn((3, 1), dtype=torch.float, requires_grad=True)\n", + "b = torch.zeros((1), dtype=torch.float, requires_grad=True)\n", + "\n", + "# 定义模型\n", + "def multi_linear(x):\n", + " return torch.mm(x, w) + b\n", + "\n", + "def get_loss(y_, y):\n", + " return torch.mean((y_ - y) ** 2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们可以画出没有更新之前的模型和真实的模型之间的对比" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqEklEQVR4nO3deXxU1f3/8deHEAib7CCCLLaKIMpi4Ata0YoCVQRUrLjiilatYFtBxJ/aKopCXWhFi4poRZZiFdoKgiDuQANCBSKCyiZbBFEwBEhyfn+cGRIwgZCZyZ2ZvJ+Px3nMdmfuZ7J85sy5536OOecQEZHkVCHoAEREJHaU5EVEkpiSvIhIElOSFxFJYkryIiJJTEleRCSJlTjJm9l4M9tmZssL3VfHzOaY2erQZe1Cjw0zszVmtsrMekQ7cBERObKj6clPAHoect89wFzn3InA3NBtzKw10B84JfScsWaWEnG0IiJyVCqWdEPn3Ptm1vyQu/sA54SuvwzMB4aG7p/snNsLfG1ma4BOwCeH20e9evVc8+aH7kJERA5n8eLF3zrn6hf1WImTfDEaOuc2AzjnNptZg9D9jYEFhbbbGLrvsJo3b05GRkaEIYmIlC9mtq64x2J14NWKuK/I+glmNtDMMswsIysrK0bhiIiUT5Em+a1m1gggdLktdP9G4PhC2zUBNhX1As65cc65dOdcev36RX7bEBGRUoo0yc8ABoSuDwCmF7q/v5lVNrMWwInAogj3JSIiR6nEY/JmNgl/kLWemW0EHgBGAlPN7EZgPXAZgHNuhZlNBVYCucDtzrm80gS4f/9+Nm7cSE5OTmmeLhFKS0ujSZMmpKamBh2KiJSCxVOp4fT0dHfogdevv/6aGjVqULduXcyKGuqXWHHOsX37dnbt2kWLFi2CDkdEimFmi51z6UU9FvdnvObk5CjBB8TMqFu3rr5FiSSwuE/ygBJ8gPSzF0lsCZHkRUSS2ZgxMGNGbF5bSb4ExowZQ6tWrbjqqquYMWMGI0eOBODNN99k5cqVB7abMGECmzYVzBS96aabDnpcRORQO3fCsGEwffoRNy2VSM94LRfGjh3LzJkzDxx87N27N+CTfK9evWjdujXgk3ybNm047rjjAHjhhReCCbiQvLw8UlJUNkgkXk2YANnZcPvtsXl99eSP4NZbb+Wrr76id+/ePPnkk0yYMIE77riDjz/+mBkzZnD33XfTrl07HnvsMTIyMrjqqqto164de/bs4ZxzzjlQpqF69eoMHz6ctm3b0rlzZ7Zu3QrAl19+SefOnenYsSP3338/1atXLzKOV155hdNOO422bdtyzTXXAHDdddcxbdq0A9uEnzt//nx++ctfcuWVV3LqqacydOhQxo4de2C7Bx98kD//+c8AjBo1io4dO3LaaafxwAMPRP8HKCLFys+HZ56BM86ADh1is4/E6skPHgxLl0b3Ndu1g6eeKvbh5557jlmzZvHuu+9Sr149JkyYAMAZZ5xB79696dWrF/369QNg5syZjB49mvT0n85k+vHHH+ncuTMjRoxgyJAhPP/889x3330MGjSIQYMGccUVV/Dcc88VGcOKFSsYMWIEH330EfXq1WPHjh1HfFuLFi1i+fLltGjRgk8//ZTBgwdz2223ATB16lRmzZrF7NmzWb16NYsWLcI5R+/evXn//ffp2rXrEV9fRCI3Zw6sWQN//GPs9qGefBmpVKkSvXr1AuD0009n7dq1AHzyySdcdtllAFx55ZVFPnfevHn069ePevXqAVCnTp0j7q9Tp04Hhpfat2/Ptm3b2LRpE8uWLaN27do0bdqU2bNnM3v2bNq3b0+HDh34/PPPWb16daRvVURK6K9/hYYNIdRPjInE6skfpscd71JTUw9MR0xJSSE3N7fEz3XOFTmVsWLFiuTn5x/YZt++fQceq1at2kHb9uvXj2nTprFlyxb69+9/4DnDhg3jlltuOer3IyKR+eor+M9/4L77oFKl2O1HPfkI1KhRg127dhV7uyQ6d+7M66+/DsDkyZOL3KZbt25MnTqV7du3AxwYrmnevDmLFy8GYPr06ezfv7/Y/fTv35/Jkyczbdq0A8NLPXr0YPz48ezevRuAb775hm3bthX7GiISPc8+CxUqQKz7WEryEejfvz+jRo2iffv2fPnll1x33XXceuutBw68lsRTTz3FE088QadOndi8eTM1a9b8yTannHIKw4cP5+yzz6Zt27b87ne/A+Dmm2/mvffeo1OnTixcuPAnvfdDX2PXrl00btyYRo0aAdC9e3euvPJKunTpwqmnnkq/fv2O+kNKRI5edja8+CJcfDE0PuJKG5GJ+9o1mZmZtGrVKqCIYi87O5sqVapgZkyePJlJkyYxPVYTZksp2X8HImVt/Hi48UaYPx/OPjvy1ztc7ZrEGpNPQosXL+aOO+7AOUetWrUYP3580CGJSAw5B3/5C7RpA2UxkU1JPmBnnXUWy5YtCzoMESkjn3ziZ4I/+yyURWkojcmLiJShv/4VjjkGrr66bPanJC8iUka++Qb+8Q+4/noo5uT2qFOSFxEpI3/9qy9lcOedZbdPJXkRkTKwezc895yfNnnCCWW3XyX5MtC8eXO+/fbboMMQkQBNmODLCv/+92W7XyX5o+CcO1BGQHGISEnl5cGTT0LnztClS9nuW0n+CNauXUurVq247bbb6NChAxs2bCi2PG/fvn05/fTTOeWUUxg3btwRX3vWrFl06NCBtm3b0q1bN8CXAR49evSBbdq0acPatWt/EsdDDz3EkCFDDmw3YcIEfvvb3wLw6quv0qlTJ9q1a8ctt9xCXl5etH4cIlIKM2b4WjVl3YuHKM2TN7O7gJsAB3wGXA9UBaYAzYG1wK+dc99Fsp8AKg0DsGrVKl566SXGjh172PK848ePp06dOuzZs4eOHTty6aWXUrdu3SJfMysri5tvvpn333+fFi1alKh8cOE4srKy6NKlC48//jgAU6ZMYfjw4WRmZjJlyhQ++ugjUlNTue2225g4cSLXXnvtUf5kRCRanngCmjeHvn3Lft8RJ3kzawzcCbR2zu0xs6lAf6A1MNc5N9LM7gHuAYZGur8gNGvWjM6dOwMcVJ4XYPfu3axevZquXbsyZswY3njjDQA2bNjA6tWri03yCxYsoGvXrgfKAZekfHDhOOrXr88JJ5zAggULOPHEE1m1ahVnnnkmzzzzDIsXL6Zjx44A7NmzhwYNGkT2AxCRUlu0CD780HcmKwZw+mm0dlkRqGJm+/E9+E3AMOCc0OMvA/OJMMkHVWm4cOGv4srzzp8/n3feeYdPPvmEqlWrcs4555CTk1Psa5akfDBw0GscWoDs8ssvZ+rUqZx88slcfPHFmBnOOQYMGMCjjz561O9TRKLviSegZk244YZg9h/xmLxz7htgNLAe2Ax875ybDTR0zm0ObbMZKLI7aWYDzSzDzDKysrIiDSfmiivP+/3331O7dm2qVq3K559/zoIFCw77Ol26dOG9997j66+/Bg4uH7xkyRIAlixZcuDxolxyySW8+eabTJo0icsvvxzwZYmnTZt2oGTwjh07WLduXWRvWkRKZd06mDYNBg6EGjWCiSEawzW1gT5AC2An8A8zK/EJu865ccA48FUoI40n1rp3705mZiZdQofIq1evzquvvkrPnj157rnnOO2002jZsuWBYZXi1K9fn3HjxnHJJZeQn59PgwYNmDNnDpdeeimvvPIK7dq1o2PHjpx00knFvkbt2rVp3bo1K1eupFOnTgC0bt2ahx9+mO7du5Ofn09qairPPPMMzZo1i94PQURKZMwYX58mNCciEBGXGjazy4CezrkbQ7evBToD3YBznHObzawRMN851/Jwr1UeSw0nAv0ORI7ezp3QtCn06gWvvRbbfR2u1HA0plCuBzqbWVXzg8zdgExgBjAgtM0AIL6KpIuIxNDYsbBrFwwNeLpJxMM1zrmFZjYNWALkAp/ih1+qA1PN7Eb8B8Flke5LRCQRZGf7k58uuADatg02lqjMrnHOPQA8cMjde/G9+mi8fpEzUST24mnlMJFE8cIL8O23cO+9QUeSAGe8pqWlsX37diWbADjn2L59O2lpaUGHIpIw9u2DUaP8qk9nnhl0NAmwMlSTJk3YuHEjiTC9MhmlpaXRpEmToMMQSRivvgobN/refDyI+ySfmpp64KxQEZF4lpcHI0dChw7QvXvQ0Xhxn+RFRBLF66/D6tX+BKh4OYwY92PyIiKJwDl45BE4+WS/MEi8UE9eRCQKZs6EZcvgpZegQhx1n+MoFBGRxBTuxTdtClddFXQ0B1NPXkQkQvPmwUcfwV/+AqmpQUdzMPXkRUQi4Bzcfz80bgw33RR0ND+lnryISARmz4aPP/a1auLxvEH15EVESinci2/aNLhFQY5EPXkRkVJ66y2/vN/zz0PlykFHUzT15EVESiHci2/RAgYMOPL2QVFPXkSkFGbMgCVL/Lz4eJtRU5h68iIiRyk/Hx54AH7+c7i6xIudBkM9eRGRo/TGG/7s1r//HSrGeRZVT15E5CiEe/EtW8IVVwQdzZHF+WeQiEh8mTgRVqyASZMgJSXoaI5MPXkRkRLKyYH77oPTT4df/zroaEpGPXkRkRJ65hlYvx7Gj4+vSpOHkyBhiogE67vvYMQI6NEDunULOpqSi0qSN7NaZjbNzD43s0wz62JmdcxsjpmtDl3Wjsa+RESCMHIk7NwJjz0WdCRHJ1o9+aeBWc65k4G2QCZwDzDXOXciMDd0W0Qk4WzYAE8/DddcA23bBh3N0Yk4yZvZMUBX4EUA59w+59xOoA/wcmizl4G+ke5LRCQI99/vLx96KNg4SiMaPfkTgCzgJTP71MxeMLNqQEPn3GaA0GWDop5sZgPNLMPMMrKysqIQjohI9Hz2Gbz8Mvz2t77aZKKJRpKvCHQAnnXOtQd+5CiGZpxz45xz6c659Pr160chHBGR6LnnHqhZE4YNCzqS0olGkt8IbHTOLQzdnoZP+lvNrBFA6HJbFPYlIlJm5s715YTvvRfq1Ak6mtKJOMk757YAG8ysZeiubsBKYAYQLsA5AJge6b5ERMpKbi4MGgQnnOCHahJVtE6G+i0w0cwqAV8B1+M/QKaa2Y3AeuCyKO1LRCTmnn3Wly944434XNavpKKS5J1zS4H0Ih5KoFMGRES8b7/1M2rOOw/69Ak6msjojFcRkUPcdx/s2uXnxpsFHU1klORFRApZuhTGjYM77oDWrYOOJnJK8iIiIc7BnXdC3brw4INBRxMdqkIpIhIydSp88AH87W9Qq1bQ0USHevIiIkB2Ntx9N7RvDzfeGHQ00aOevIgIvi7Nhg1+5adEWPGppNSTF5Fy77PPYPRouP56OOusoKOJLiV5ESnX8vPhllv8GPyoUUFHE30arhGRcm3cOPjkE19psm7doKOJPvXkRaTc2rzZV5k891y/IEgyUpIXkXLrrrsgJ8fXqUn0M1uLoyQvIuXSzJkwZQoMHw4nnRR0NLGjJC8i5U52Ntx2G5x8MgwZEnQ0saUDryJS7gwfDmvXwnvvQeXKQUcTW+rJi0i58v77vrrk7bdD165BRxN7SvIiUm7s3u1PeGrRAh57LOhoyoaGa0Sk3Bg6FL7+2g/TVKsWdDRlQz15ESkX5s6FsWNh8ODkK11wOEryIpL0fvgBbrgBWraEESOCjqZsabhGRJLe738PGzfCRx9BlSpBR1O21JMXkaT21lvwwgu+VnznzkFHU/ailuTNLMXMPjWzf4du1zGzOWa2OnRZO1r7EhEpic2b4brr4NRT4Y9/DDqaYESzJz8IyCx0+x5grnPuRGBu6LaISJnIz4drr/XTJqdMSf6TnooTlSRvZk2AC4EXCt3dB3g5dP1loG809iUiUhKjRsE778CYMdCqVdDRBCdaPfmngCFAfqH7GjrnNgOELhsU9UQzG2hmGWaWkZWVFaVwRKQ8W7gQ7rsPLrssudZrLY2Ik7yZ9QK2OecWl+b5zrlxzrl051x6/fr1Iw1HRMq577+HK66Axo39giDJWkK4pKIxhfJMoLeZXQCkAceY2avAVjNr5JzbbGaNgG1R2JeISLGcg1tvhfXrfY2aWrWCjih4EffknXPDnHNNnHPNgf7APOfc1cAMYEBoswHA9Ej3JSJyOC++CJMn+5k0Z5wRdDTxIZbz5EcC55vZauD80G0RkZj473/hjjvgvPP8kn7iRfWMV+fcfGB+6Pp2oFs0X19EpCjbtsGll8Kxx8KkSZCSEnRE8UNlDUQkoeXmQv/+kJXlyxbUqxd0RPFFSV5EEtqwYfDuuzBhAnToEHQ08Ue1a0QkYU2dCqNH+1WeBgw48vblkZK8iCSk5ct9+eAzz4Qnngg6mvilJC8iCWfrVrjoIqhRA/7xD6hUKeiI4pfG5EUkoWRnQ+/efkbNe+9Bo0ZBRxTflORFJGHk58M11/g58W+8AenpQUcU/5TkRSRh3HMP/POf8OST0KdP0NEkBo3Ji0hC+NvffPng22+HQYOCjiZxKMmLSNybNcsn9wsugKeeUmXJo6EkLyJx7eOPfcmCNm188bGKGmQ+KkryIhK3li71vffjjoO33/ZTJuXoKMmLSFz64gvo3h2OOcYv49ewYdARJSYleRGJO+vX+5LBAHPmQLNmwcaTyDS6JSJxZetWOP98+OEHX3isZcugI0psSvIiEje2bvU9+A0bfA++ffugI0p8SvIiEhc2bYJu3WDdOvjXv3zhMYmckryIBG7DBjj3XNiyxc+J79o16IiSh5K8iATq6699gt+xA2bPhi5dgo4ouSjJi0hg1qzxCX73bpg7VwXHYiHiKZRmdryZvWtmmWa2wswGhe6vY2ZzzGx16LJ25OGKSLJYvBh+8QvYs8fPolGCj41ozJPPBX7vnGsFdAZuN7PWwD3AXOfcicDc0G0REd56C84+G9LS4IMPoG3boCNKXhEneefcZufcktD1XUAm0BjoA7wc2uxloG+k+xKRxPfCC37Rj5YtYcECOPnkoCNKblE949XMmgPtgYVAQ+fcZvAfBECDaO5LRBKLc3D//XDzzf5kp/nz4dhjg44q+UUtyZtZdeB1YLBz7oejeN5AM8sws4ysrKxohSMicSQnB667Dh56CG68EWbMULGxshKVJG9mqfgEP9E598/Q3VvNrFHo8UbAtqKe65wb55xLd86l169fPxrhiEgc2bABzjoLXnkF/vQneP55SE0NOqryIxqzawx4Ech0zj1R6KEZwIDQ9QHA9Ej3JSKJ5b334PTTYdUqePNN+H//Twt+lLVo9OTPBK4BzjWzpaF2ATASON/MVgPnh26LSDngHDz9tC9TULcuLFqkNVmDEvHJUM65D4HiPpu7Rfr6IpJYdu2C3/wGJk6Evn3h5Zd9TXgJhurJi0jULFzoK0dOmuQPsr7+uhJ80JTkRSRieXkwYoSvHLl/vx+Lv+8+qKAMEzjVrhGRiGzYAFdfDe+/D/37w7PPQq1aQUclYfqcFZFSyc/30yFPPRWWLPFj76+9pgQfb5TkReSorVoFv/wlDBzox+CXLoVrr9X0yHikJC8iJbZvnx97b9sW/vc/X4dm3jz42c+CjkyKozF5ESmRefNg0CBYvhwuuwzGjFHtmUSgnryIHNYXX/iqkd26+cU9pk+HqVOV4BOFkryIFGnHDhg8GE45xVeMfPRRyMz0CV8Sh4ZrROQgP/wAf/kL/PnP8P33cNNNvrBYw4ZBRyaloSQvIoBP6GPGwJNPwnffQa9e8MgjfoqkJC4leZFy7ttvYexYn9x37vTDMfff76tHSuJTkhcppz77zFeKnDjRL+rRp49P7h06BB2ZRJOSvEg5kpsL//mPT+7vvgtVqviTmO680x9gleSjJC9SDqxY4csO/P3vsGULHH88jBzp11utUyfo6CSWlORFktSWLTBtmk/uGRlQsSJccIFfa/Wii/xtSX76NYskka++gjfe8O3jj/0KTe3awVNPwRVXQIMGQUcoZU1JXiSB5eT4ZP7OO/DWW7Bsmb+/bVt48EG45BJo0ybQECVgSvIiCWTvXvj0U1+7/Z134IMPfKJPSYEuXWD0aLj4YjjhhKAjlXihJC8Sp/Ly4Msvfa32BQt8+/RTXwkSfA/91lvhvPOga1eoUSPYeCU+KcmLBGzfPli71if0zEw/f/2zz2DlStizx29TtSp07Ah33QWdO/umAmFSEjFP8mbWE3gaSAFecM6NjPU+ReKBc5CdDdu3w9atsGlTQfvmG1i3zif2DRv8Kkthxx7rSwn85jf+sm1bf6nZMFIaMf2zMbMU4BngfGAj8F8zm+GcWxnL/UrZ2bPHJ7CtW30y27nz4LZrl0902dl+2+xsP4a8f78/MWf/ft/y8vzrOXfw66ek+MWgU1IKWmqqT3iFL1NToVIl3wrfPvR64eeFW4UKBc3MXzrnE2/hy337Ctrevf4yO9uX3929G3780V/u3Ol/Fjt2+O0OVaGCT+RNm8IvfuEX3Ai3k06C+vVj+iuTcibWfYNOwBrn3FcAZjYZ6ANEN8nv3OmnFjRp4s/yaNzY/1dLRJyDjRv9tLy1aw9umzb5edg//FD886tUgerVoVo1f71qVd+qVStIvuGkm5JSsHRc+DKcYPPyDm7hD4fcXJ9Ed+/2t/ftK7gMXw+38O1oMXNUruSompZP9ar5VK+SR/UqeVRLy+PEurl0bpFLnRr7qXPMfuoek0uD2vtpXH8vx9XbT4M6uaRUNP9Gw59c4U+yrBTYmVrwAwq3ypV9S0mJ3puQciHWSb4xsKHQ7Y3A/0V9L5mZcNVVB9/XsKFP+I0aHdyOO84/1qCBb1WrRj2cRLR5Myxe7M+MXLnS/0gzM30CDTPzn5/Nmvm518ce63+U4ct69fwizuFWuXIUAtu/v+ArwKFfCQpfHtpycn5y6bL3kL93P7l79pObk0tuTi77c/Jw+/aTvy/3oFZhXw4VyMNwVCAfw1GJfQdaisuHvfj2fRTeZ0lVrOh/sGlp/pMz/OkZvqxRw3+yhluNGnDMMVCzZkGrVQtq1/anutas6T9gJGnFOskXtazvQV/IzWwgMBCgadOmpdtLhw4+I23Y4NvGjQXX163z0xKysop+brVq/vtxvXr+j75uXd/q1PH/CIWzVs2a/h+mRg3f0tIScuXiXbvgk09g0SJ/JmRGhh8jDjvuOGjdGm64AVq1gp//HJo395+ZlSvju9PFJdVvsmFNEUn40AR9aCvu/tzc0r3JSpV84gsnw7Q0rHJlUqpUIaVyZSrXqQxpNQt6yJUr++dUrlzQcy5qzKfwOE9KSsFl4d54eNyncAP/1STcoOBrSuHLwmNYhb+GhMeIwi388y78s/vxRz9OFB4/Co8hHY6Z/9uuU8f/D9Sv71uDBv7y2GMPbnXqJOTffHlm7tBB0Gi+uFkX4EHnXI/Q7WEAzrlHi9o+PT3dZWRkxCaYffv8wPHmzbBtm29ZWf4yPKC8Y4e/DA8uH0lKik/21ar5Fh6PCLdwjystrSCRHJo4wmMV4VZ4kDg8QFxUoghfL5wgCo9t5OYeaLt+rMBHXzdi/lfNmL+uBRlbm5Dn/Nf+lsdsIr3Wl6Qf8wXp1TI5tdIqaubtKEgk4cvCrbSJt0KFn47dhFuVKj/9GR66XeGea+HHw/eHW1qahjXC8vL8p/r33xe0nTt9wfgdO3z77jv/N5+VdXAranwrNdX3Apo0KWiNG/sDDM2a+d5A3br6IChjZrbYOZde5GMxTvIVgS+AbsA3wH+BK51zK4raPqZJ/mjl5voB5/A/ReEjiYe2H38suhd6aILcu7egdxZja/gZ/+Ii/sVFfMBZ5JJKKvvoVCGDc1I/5uy0hXSqtoKaaXsP7sEW/kAKXw9/UIUTaOHrhRProUMH4cfCiT01Vf/8icI5/7e/das/+BJumzf7r30bNxa0nJyDn1u1qk/4J5xw8FHln/0MWrSI0jieFHa4JB/T4RrnXK6Z3QG8jZ9COb64BB93Klb0X01jUaLPuZ9OLSncA8/LO/irfXh6R+Ejk+HroeEBVyGFJcsrMeVfVZjxdhqr1viebOtW+fzuQji/B3TpUolq1c4Azoj+e5LkEh7GqVULWrYsfjvn/LeB9ev90Ojatf5y3Tp/xH7+/IOHjCpU8Im+ZcuC1qqVr3Nct25s31M5FdOe/NGKq558gvjyS7/ow2uvwapVvrN89tm+ymCvXjq9XQLmnB8S/fJL3774wv+hrlrlrxf+FtCwoU/2p5ziTwxo186f1lulSmDhJ4rAevISG9nZMGkSPP88LFzo7zv7bPj97+HSS1UfXOKImU/eDRvCGYd8g8zP95MjMjP9tK4VK2D5chg/vqD3X6GCP3mgXTvf0tP9RIvatcv6nSQs9eQTyKpV8OyzMGGCHy5t08av6tO/v5/5IpIU8vP9sM/Spb6s5rJl/vq6dQXb/OxnPuF37OhrPHToUK57/IEdeD1aSvI/5RzMnAlPPAFz5/rhmMsu86e8n3mmjmNKObJ9u6/WFp73m5HhjwWA/8do184n/DPO8KcSN2kSaLhlSUk+AeXl+VV9Hn3Ud2SOP94n9htu8N98RQQ/42fhwoIynYsW+fFM8NM5zzrLt65d/bBPkvaKlOQTyL59fh3Oxx6D1avh5JPhnnvgyit9Z0VEDiM3F/73P19o/4MPfOH98ImQjRrBuecWtObNAw01mpTkE0B+PkydCsOH+5lnp58O994LffvqrHORUnPOz+J57z14912YN8/P9gE/lbNHD9/OPdefzZ6glOTj3Lx5MGSIrx1z2ml+iOZXv0rab5YiwXHOF2eaN88vrTVvni//ULGiX1qrZ0+48EL/j5hA/4BK8nFqxQr4wx9g1iw/5v7ww77Oms7IFykj+/b5RXLfftu3Tz/19x9/vD/R5KKL4Je/9Gd0xzEl+Tjz44/wpz/5GTPVq/shmjvuiPu/I5Hkt2UL/Oc/8O9/w+zZ/iBu1aq+h3/JJb6XX6tW0FH+hJJ8HJk+He6808/8uv56ePxxX/xPROJMTo4vyzBjBrz5pq/bk5oK3br5hN+3b9ys8HK4JK9DemVk/Xro08f/XdSo4Q/8jx+vBC8St9LSfA9+7FhfiO3jj2HwYD/tbeBAP1unRw946SVfyTNOKcnHmHP+DNU2bfxxnscf98N+v/hF0JGJSIlVqOAPzD7+uE/yS5f62RJr1hScvHLRRTB5csHq63FCST6Gtm3z3+quv96fjLd8Odx9t+a7iyQ0M7+6+iOP+CS/aJEfg126FK64wif8G27wUzYLr9AeECX5GJkxwxfSe+stGD3a/75btAg6KhGJKjNfP2f0aF9bZ9486NfPn64ePuHqvvv8yS8BUZKPsj17/HBdnz5+AZ3Fi311SE2LFElyFSr46Zbjx/tZOpMm+XHaRx/1BdXOPdfXBS/j4Rwl+Sj66itfNOz5530pgoUL/e9YRMqZqlV9edi33vI9/Ice8pU1r77a9/7uvNOflFUGlOSjZMYMX+107Vo/xfbRR/2qeiJSzjVp4ods1qzxpWQvuAD+9je/OMo558CUKf6krBhRko9Qbq7vtffpAz//ua+EeuGFQUclInGnQoWCIZuNG2HkSD+3OrwgxKhRsdltTF61nPjuOz9N9rHH4JZb4MMPk6qwnYjESv36MHSo792/9Rb83//FbK69lv8rpS+/9D32r77y50Jcd13QEYlIwqlQwVcj/NWv/Ek1MaAkXwoffujPXHXOn+DUtWvQEYlIwotR1cuIhmvMbJSZfW5m/zOzN8ysVqHHhpnZGjNbZWY9Io40Tkyc6EtX1K3rF6JRgheReBbpmPwcoI1z7jTgC2AYgJm1BvoDpwA9gbFmltAzxZ3zlSOvvtovIfnJJ3DiiUFHJSJyeBEleefcbOdcbujmAiC8cm4fYLJzbq9z7mtgDdApkn0FKT/f1yV64AG49lpfdrpOnaCjEhE5smjOrrkBmBm63hjYUOixjaH7fsLMBppZhpllZIXXYowjublw440wZgzcdZcvNqb57yKSKI6Y5M3sHTNbXkTrU2ib4UAuMDF8VxEvVeShY+fcOOdcunMuvX6c1GYO27vXT2GdMAEefBD+/OeEWhFMROTIs2ucc+cd7nEzGwD0Arq5ghVINgLHF9qsCbCptEEGITvbV5B8+2148kk/XCMikmginV3TExgK9HbOZRd6aAbQ38wqm1kL4ERgUST7Kks//uinrc6ZAy++qAQvIokr0nnyfwUqA3PMj2MscM7d6pxbYWZTgZX4YZzbnXN5Ee6rTOzZA717+7nwr70Gl18edEQiIqUXUZJ3zv38MI+NAEZE8vplbe9eP0Tz7rvw978rwYtI4tMZryH798Ovfw2zZsELL8BVVwUdkYhI5FSgDD9N8qqrfLngZ57xUyZFRJJBuU/y+fl+OcZ//MNPkbzttqAjEhGJnnKf5IcO9ePvDz0Ev/td0NGIiERXuU7yY8b49Xdvvx2GDw86GhGR6Cu3Sf711/3897594emndSariCSncpnkP/rIH2jt3NnPhU9J6PqYIiLFK3dJ/vPP4aKLoFkzP5umSpWgIxIRiZ1yleSzsny5gtRUmDkT6tULOiIRkdgqNydD7d8P/frBli3w/vtwwglBRyQiEnvlJskPHuyT+8SJ0LFj0NGIiJSNcjFcM24cjB0LQ4bAlVcGHY2ISNlJ+iT/wQd+HnzPnvDII0FHIyJStpI6ya9fD5deCi1awKRJmiopIuVP0ib5PXvg4oshJwemT4datYKOSESk7CXtgdfBg2HJEj8XvlWroKMREQlGUvbkJ03yB1uHDPEnPomIlFdJl+S/+AIGDoQzzoCHHw46GhGRYCVVks/J8as7VaoEkyf7M1tFRMqzpBqTv+suWLYM/v1vOP74oKMREQle0vTkp0yB556Du++GCy8MOhoRkfgQlSRvZn8wM2dm9QrdN8zM1pjZKjPrEY39FGf1arj5ZujSBUaMiOWeREQSS8TDNWZ2PHA+sL7Qfa2B/sApwHHAO2Z2knMuL9L9FaViRZ/gn39e4/AiIoVFoyf/JDAEcIXu6wNMds7tdc59DawBOkVhX0Vq0QLefhuaNo3VHkREElNESd7MegPfOOeWHfJQY2BDodsbQ/cV9RoDzSzDzDKysrIiCUdERA5xxOEaM3sHOLaIh4YD9wLdi3paEfe5Iu7DOTcOGAeQnp5e5DYiIlI6R0zyzrnzirrfzE4FWgDLzK+C3QRYYmad8D33wpMYmwCbIo5WRESOSqmHa5xznznnGjjnmjvnmuMTewfn3BZgBtDfzCqbWQvgRGBRVCIWEZESi8nJUM65FWY2FVgJ5AK3x2pmjYiIFC9qST7Umy98ewSgWesiIgFKmjNeRUTkp5TkRUSSmDkXP7MWzSwLWBfBS9QDvo1SOEFKlvcBei/xKFneB+i9hDVzztUv6oG4SvKRMrMM51x60HFEKlneB+i9xKNkeR+g91ISGq4REUliSvIiIkks2ZL8uKADiJJkeR+g9xKPkuV9gN7LESXVmLyIiBws2XryIiJSSFIleTN7yMz+Z2ZLzWy2mR0XdEylZWajzOzz0Pt5w8xqBR1TaZnZZWa2wszyzSzhZkKYWc/QCmdrzOyeoOMpLTMbb2bbzGx50LFEysyON7N3zSwz9Lc1KOiYSsPM0sxskZktC72PP0Z9H8k0XGNmxzjnfghdvxNo7Zy7NeCwSsXMugPznHO5ZvYYgHNuaMBhlYqZtQLygb8Bf3DOZQQcUomZWQrwBX71s43Af4ErnHMrAw2sFMysK7AbeMU51yboeCJhZo2ARs65JWZWA1gM9E2034v5Er7VnHO7zSwV+BAY5JxbEK19JFVPPpzgQ6pRTA37ROCcm+2cyw3dXIAv15yQnHOZzrlVQcdRSp2ANc65r5xz+4DJ+JXPEo5z7n1gR9BxRINzbrNzbkno+i4gk2IWJopnztsdupkaalHNW0mV5AHMbISZbQCuAu4POp4ouQGYGXQQ5VSJVzmTYJhZc6A9sDDgUErFzFLMbCmwDZjjnIvq+0i4JG9m75jZ8iJaHwDn3HDn3PHAROCOYKM9vCO9l9A2w/HlmicGF+mRleS9JKgSr3ImZc/MqgOvA4MP+SafMJxzec65dvhv653MLKpDaTGpJx9Lxa1UVYTXgP8AD8QwnIgc6b2Y2QCgF9DNxfnBk6P4vSQarXIWp0Jj2K8DE51z/ww6nkg553aa2XygJxC1g+MJ15M/HDM7sdDN3sDnQcUSKTPrCQwFejvnsoOOpxz7L3CimbUws0pAf/zKZxKg0AHLF4FM59wTQcdTWmZWPzxzzsyqAOcR5byVbLNrXgda4mdyrANudc59E2xUpWNma4DKwPbQXQsSeKbQxcBfgPrATmCpc65HoEEdBTO7AHgKSAHGhxbESThmNgk4B1/tcCvwgHPuxUCDKiUz+wXwAfAZ/v8d4F7n3FvBRXX0zOw04GX831YFYKpz7k9R3UcyJXkRETlYUg3XiIjIwZTkRUSSmJK8iEgSU5IXEUliSvIiIklMSV5EJIkpyYuIJDEleRGRJPb/AT43r/9qplA4AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# 画出更新之前的模型\n", + "y_pred = multi_linear(x_train)\n", + "\n", + "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", + "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以发现,这两条曲线之间存在差异,我们计算一下他们之间的误差" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(1144.2654, grad_fn=)\n" + ] + } + ], + "source": [ + "# 计算误差,这里的误差和一元的线性模型的误差是相同的,前面已经定义过了 get_loss\n", + "loss = get_loss(y_pred, y_train)\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# 自动求导\n", + "loss.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ -94.7455],\n", + " [-139.1247],\n", + " [-629.8584]])\n", + "tensor([-25.7413])\n" + ] + } + ], + "source": [ + "# 查看一下 w 和 b 的梯度\n", + "print(w.grad)\n", + "print(b.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# 更新一下参数\n", + "w.data = w.data - 0.001 * w.grad.data\n", + "b.data = b.data - 0.001 * b.grad.data" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAApTklEQVR4nO3deXxU1f3/8deHEHZlCyD78i0iyC7wBRfUougXKeBWccVdiwvYVpDiz6WWVsWVKlpUBBVBilXSKsgmLggioFTZBNllCyAIhC3J+f1xJiRgAklmJndm8n4+HvdxZ+7cmfuZLJ85c+65n2POOUREJDGVCjoAERGJHiV5EZEEpiQvIpLAlORFRBKYkryISAJTkhcRSWAFTvJmNtrMtpnZd7m2VTOz6Wa2MrSumuuxIWa2ysxWmNlFkQ5cREROrDAt+THAxcdsewCY6ZxrCswM3cfMWgB9gdNDzxlpZklhRysiIoVSuqA7Ouc+NbNGx2zuDZwXuj0WmA0MDm2f4Jw7CKwxs1VAJ2Du8Y6RkpLiGjU69hAiInI8Cxcu3O6cq5HXYwVO8vmo5ZzbDOCc22xmNUPb6wLzcu23MbTtuBo1asSCBQvCDElEpGQxs3X5PRatE6+Wx7Y86yeY2e1mtsDMFqSlpUUpHBGRkincJL/VzGoDhNbbQts3AvVz7VcP2JTXCzjnRjnnOjjnOtSokee3DRERKaJwk3wq0C90ux8wOdf2vmZW1swaA02B+WEeS0RECqnAffJmNh5/kjXFzDYCDwOPAxPN7BZgPXAlgHNuiZlNBJYCGcBdzrnMogR4+PBhNm7cyIEDB4rydAlTuXLlqFevHsnJyUGHIiJFYLFUarhDhw7u2BOva9as4aSTTqJ69eqY5dXVL9HinGPHjh3s2bOHxo0bBx2OiOTDzBY65zrk9VjMX/F64MABJfiAmBnVq1fXtyiROBbzSR5Qgg+QfvYi8S0ukryISCIbMQJSU6Pz2kryBTBixAiaN2/OtddeS2pqKo8//jgA77//PkuXLj2y35gxY9i0KWek6K233nrU4yIix9q1C4YMgcmTT7hrkYR7xWuJMHLkSKZMmXLk5GOvXr0An+R79uxJixYtAJ/kW7ZsSZ06dQB49dVXgwk4l8zMTJKSVDZIJFaNGQPp6XDXXdF5fbXkT+DOO+9k9erV9OrVi2effZYxY8Zw991388UXX5Camsr9999P27ZteeKJJ1iwYAHXXnstbdu2Zf/+/Zx33nlHyjRUqlSJoUOH0qZNGzp37szWrVsB+OGHH+jcuTMdO3bkoYceolKlSnnG8cYbb9C6dWvatGnD9ddfD8CNN97IpEmTjuyT/dzZs2dz/vnnc80119CqVSsGDx7MyJEjj+z3yCOP8PTTTwMwfPhwOnbsSOvWrXn44Ycj/wMUkXxlZcGLL8KZZ0L79tE5Rny15AcOhG++iexrtm0Lzz2X78Mvv/wyU6dO5eOPPyYlJYUxY8YAcOaZZ9KrVy969uzJFVdcAcCUKVN46qmn6NDhlyOZ9u3bR+fOnRk2bBiDBg3ilVde4cEHH2TAgAEMGDCAq6++mpdffjnPGJYsWcKwYcOYM2cOKSkp7Ny584Rva/78+Xz33Xc0btyYr7/+moEDB9K/f38AJk6cyNSpU5k2bRorV65k/vz5OOfo1asXn376KV27dj3h64tI+KZPh1Wr4NFHo3cMteSLSZkyZejZsycAZ5xxBmvXrgVg7ty5XHnllQBcc801eT531qxZXHHFFaSkpABQrVq1Ex6vU6dOR7qX2rVrx7Zt29i0aROLFy+matWqNGjQgGnTpjFt2jTatWtH+/btWb58OStXrgz3rYpIAb3wAtSqBaF2YlTEV0v+OC3uWJecnHxkOGJSUhIZGRkFfq5zLs+hjKVLlyYrK+vIPocOHTryWMWKFY/a94orrmDSpEls2bKFvn37HnnOkCFDuOOOOwr9fkQkPKtXwwcfwIMPQpky0TuOWvJhOOmkk9izZ0++9wuic+fOvPvuuwBMmDAhz326devGxIkT2bFjB8CR7ppGjRqxcOFCACZPnszhw4fzPU7fvn2ZMGECkyZNOtK9dNFFFzF69Gj27t0LwI8//si2bdvyfQ0RiZyXXoJSpSDabSwl+TD07duX4cOH065dO3744QduvPFG7rzzziMnXgviueee45lnnqFTp05s3ryZypUr/2Kf008/naFDh3LuuefSpk0bfv/73wNw22238cknn9CpUye+/PLLX7Tej32NPXv2ULduXWrXrg1A9+7dueaaa+jSpQutWrXiiiuuKPSHlIgUXno6vPYaXHop1D3hTBvhifnaNcuWLaN58+YBRRR96enplC9fHjNjwoQJjB8/nsnRGjBbRIn+OxApbqNHwy23wOzZcO654b/e8WrXxFeffAJauHAhd999N845qlSpwujRo4MOSUSiyDn4+9+hZUsojoFsSvIBO+ecc1i8eHHQYYhIMZk7148Ef+klKI7SUOqTFxEpRi+8ACefDNddVzzHU5IXESkmP/4I//wn3HQT5HNxe8QpyYuIFJMXXvClDO69t/iOqSQvIlIM9u6Fl1/2wyabNCm+4yrJF4NGjRqxffv2oMMQkQCNGePLCv/hD8V7XCX5QnDOHSkjoDhEpKAyM+HZZ6FzZ+jSpXiPrSR/AmvXrqV58+b079+f9u3bs2HDhnzL8/bp04czzjiD008/nVGjRp3wtadOnUr79u1p06YN3bp1A3wZ4KeeeurIPi1btmTt2rW/iOOxxx5j0KBBR/YbM2YM99xzDwBvvfUWnTp1om3bttxxxx1kZmZG6schIkWQmupr1RR3Kx4iNE7ezO4DbgUc8C1wE1ABeAdoBKwFfuuc+ymc4wRQaRiAFStW8PrrrzNy5MjjlucdPXo01apVY//+/XTs2JHLL7+c6tWr5/maaWlp3HbbbXz66ac0bty4QOWDc8eRlpZGly5dePLJJwF45513GDp0KMuWLeOdd95hzpw5JCcn079/f8aNG8cNN9xQyJ+MiETKM89Ao0bQp0/xHzvsJG9mdYF7gRbOuf1mNhHoC7QAZjrnHjezB4AHgMHhHi8IDRs2pHPnzgBHlecF2Lt3LytXrqRr166MGDGC9957D4ANGzawcuXKfJP8vHnz6Nq165FywAUpH5w7jho1atCkSRPmzZtH06ZNWbFiBWeddRYvvvgiCxcupGPHjgDs37+fmjVrhvcDEJEimz8fPv/cNyZLB3D5aaQOWRoob2aH8S34TcAQ4LzQ42OB2YSZ5IOqNJy78Fd+5Xlnz57NjBkzmDt3LhUqVOC8887jwIED+b5mQcoHA0e9xrEFyK666iomTpzIaaedxqWXXoqZ4ZyjX79+/O1vfyv0+xSRyHvmGahcGW6+OZjjh90n75z7EXgKWA9sBnY756YBtZxzm0P7bAbybE6a2e1mtsDMFqSlpYUbTtTlV5539+7dVK1alQoVKrB8+XLmzZt33Nfp0qULn3zyCWvWrAGOLh+8aNEiABYtWnTk8bxcdtllvP/++4wfP56rrroK8GWJJ02adKRk8M6dO1m3bl14b1pEimTdOpg0CW6/HU46KZgYItFdUxXoDTQGdgH/NLMCX7DrnBsFjAJfhTLceKKte/fuLFu2jC6hU+SVKlXirbfe4uKLL+bll1+mdevWNGvW7Ei3Sn5q1KjBqFGjuOyyy8jKyqJmzZpMnz6dyy+/nDfeeIO2bdvSsWNHTj311Hxfo2rVqrRo0YKlS5fSqVMnAFq0aMFf/vIXunfvTlZWFsnJybz44os0bNgwcj8EESmQESN8fZrQmIhAhF1q2MyuBC52zt0Sun8D0BnoBpznnNtsZrWB2c65Zsd7rZJYajge6HcgUni7dkGDBtCzJ7z9dnSPdbxSw5EYQrke6GxmFcx3MncDlgGpQL/QPv2A2CqSLiISRSNHwp49MDjg4SZhd9c45740s0nAIiAD+Brf/VIJmGhmt+A/CK4M91giIvEgPd1f/NSjB7RpE2wsERld45x7GHj4mM0H8a36SLx+niNRJPpiaeYwkXjx6quwfTv86U9BRxIHV7yWK1eOHTt2KNkEwDnHjh07KFeuXNChiMSNQ4dg+HA/69NZZwUdTRzMDFWvXj02btxIPAyvTETlypWjXr16QYchEjfeegs2bvSt+VgQ80k+OTn5yFWhIiKxLDMTHn8c2reH7t2DjsaL+SQvIhIv3n0XVq70F0DFymnEmO+TFxGJB87BX/8Kp53mJwaJFWrJi4hEwJQpsHgxvP46lIqh5nMMhSIiEp+yW/ENGsC11wYdzdHUkhcRCdOsWTBnDvz975CcHHQ0R1NLXkQkDM7BQw9B3bpw661BR/NLasmLiIRh2jT44gtfqyYWrxtUS15EpIiyW/ENGgQ3KciJqCUvIlJEH37op/d75RUoWzboaPKmlryISBFkt+IbN4Z+/U68f1DUkhcRKYLUVFi0yI+Lj7URNbmpJS8iUkhZWfDww/CrX8F1BZ7sNBhqyYuIFNJ77/mrW998E0rHeBZVS15EpBCyW/HNmsHVVwcdzYnF+GeQiEhsGTcOliyB8eMhKSnoaE5MLXkRkQI6cAAefBDOOAN++9ugoykYteRFRAroxRdh/XoYPTq2Kk0eT5yEKSISrJ9+gmHD4KKLoFu3oKMpuIgkeTOrYmaTzGy5mS0zsy5mVs3MppvZytC6aiSOJSIShMcfh1274Ikngo6kcCLVkn8emOqcOw1oAywDHgBmOueaAjND90VE4s6GDfD883D99dCmTdDRFE7YSd7MTga6Aq8BOOcOOed2Ab2BsaHdxgJ9wj2WiEgQHnrIrx97LNg4iiISLfkmQBrwupl9bWavmllFoJZzbjNAaF0zryeb2e1mtsDMFqSlpUUgHBGRyPn2Wxg7Fu65x1ebjDeRSPKlgfbAS865dsA+CtE145wb5Zzr4JzrUKNGjQiEIyISOQ88AJUrw5AhQUdSNJFI8huBjc65L0P3J+GT/lYzqw0QWm+LwLFERIrNzJm+nPCf/gTVqgUdTdGEneSdc1uADWbWLLSpG7AUSAWyC3D2AyaHeywRkeKSkQEDBkCTJr6rJl5F6mKoe4BxZlYGWA3chP8AmWhmtwDrgSsjdCwRkah76SVfvuC992JzWr+CikiSd859A3TI46E4umRARMTbvt2PqLngAujdO+howqMrXkVEjvHgg7Bnjx8bbxZ0NOFRkhcRyeWbb2DUKLj7bmjRIuhowqckLyIS4hzcey9Urw6PPBJ0NJGhKpQiIiETJ8Jnn8E//gFVqgQdTWSoJS8iAqSnw/33Q7t2cMstQUcTOWrJi4jg69Js2OBnfoqHGZ8KSi15ESnxvv0WnnoKbroJzjkn6GgiS0leREq0rCy44w7fBz98eNDRRJ66a0SkRBs1CubO9ZUmq1cPOprIU0teREqszZt9lclf/9pPCJKIlORFpMS67z44cMDXqYn3K1vzoyQvIiXSlCnwzjswdCicemrQ0USPkryIlDjp6dC/P5x2GgwaFHQ00aUTryJS4gwdCmvXwiefQNmyQUcTXWrJi0iJ8umnvrrkXXdB165BRxN9SvIiUmLs3esveGrcGJ54Iuhoioe6a0SkxBg8GNas8d00FSsGHU3xUEteREqEmTNh5EgYODDxShccj5K8iCS8n3+Gm2+GZs1g2LCgoyle6q4RkYT3hz/Axo0wZw6ULx90NMVLLXkRSWgffgivvuprxXfuHHQ0xS9iSd7MkszsazP7T+h+NTObbmYrQ+uqkTqWiEhBbN4MN94IrVrBo48GHU0wItmSHwAsy3X/AWCmc64pMDN0X0SkWGRlwQ03+GGT77yT+Bc95SciSd7M6gGXAK/m2twbGBu6PRboE4ljiYgUxPDhMGMGjBgBzZsHHU1wItWSfw4YBGTl2lbLObcZILSumdcTzex2M1tgZgvS0tIiFI6IlGRffgkPPghXXplY87UWRdhJ3sx6AtuccwuL8nzn3CjnXAfnXIcaNWqEG46IlHC7d8PVV0Pdun5CkEQtIVxQkRhCeRbQy8x6AOWAk83sLWCrmdV2zm02s9rAtggcS0QkX87BnXfC+vW+Rk2VKkFHFLywW/LOuSHOuXrOuUZAX2CWc+46IBXoF9qtHzA53GOJiBzPa6/BhAl+JM2ZZwYdTWyI5jj5x4ELzWwlcGHovohIVHz1Fdx9N1xwgZ/ST7yIXvHqnJsNzA7d3gF0i+Tri4jkZds2uPxyOOUUGD8ekpKCjih2qKyBiMS1jAzo2xfS0nzZgpSUoCOKLUryIhLXhgyBjz+GMWOgffugo4k9ql0jInFr4kR46ik/y1O/fifevyRSkheRuPTdd7588FlnwTPPBB1N7FKSF5G4s3Ur/OY3cNJJ8M9/QpkyQUcUu9QnLyJxJT0devXyI2o++QRq1w46otimJC8icSMrC66/3o+Jf+896NAh6Ihin5K8iMSNBx6Af/0Lnn0WevcOOpr4oD55EYkL//iHLx98110wYEDQ0cQPJXkRiXlTp/rk3qMHPPecKksWhpK8iMS0L77wJQtatvTFx0qrk7lQlORFJGZ9841vvdepAx995IdMSuEoyYtITPr+e+jeHU4+2U/jV6tW0BHFJyV5EYk569f7ksEA06dDw4bBxhPP1LslIjFl61a48EL4+WdfeKxZs6Ajim9K8iISM7Zu9S34DRt8C75du6Ajin9K8iISEzZtgm7dYN06+Pe/feExCZ+SvIgEbsMG+PWvYcsWPya+a9egI0ocSvIiEqg1a3yC37kTpk2DLl2CjiixKMmLSGBWrfIJfu9emDlTBceiIewhlGZW38w+NrNlZrbEzAaEtlczs+lmtjK0rhp+uCKSKBYuhLPPhv37/SgaJfjoiMQ4+QzgD8655kBn4C4zawE8AMx0zjUFZobui4jw4Ydw7rlQrhx89hm0aRN0RIkr7CTvnNvsnFsUur0HWAbUBXoDY0O7jQX6hHssEYl/r77qJ/1o1gzmzYPTTgs6osQW0StezawR0A74EqjlnNsM/oMAqBnJY4lIfHEOHnoIbrvNX+w0ezacckrQUSW+iCV5M6sEvAsMdM79XIjn3W5mC8xsQVpaWqTCEZEYcuAA3HgjPPYY3HILpKaq2FhxiUiSN7NkfIIf55z7V2jzVjOrHXq8NrAtr+c650Y55zo45zrUqFEjEuGISAzZsAHOOQfeeAP+/Gd45RVITg46qpIjEqNrDHgNWOaceybXQ6lAv9DtfsDkcI8lIvHlk0/gjDNgxQp4/334f/9PE34Ut0i05M8Crgd+bWbfhJYewOPAhWa2ErgwdF9ESgDn4PnnfZmC6tVh/nzNyRqUsC+Gcs59DuT32dwt3NcXkfiyZw/87ncwbhz06QNjx/qa8BIM1ZMXkYj58ktfOXL8eH+S9d13leCDpiQvImHLzIRhw3zlyMOHfV/8gw9CKWWYwKl2jYiEZcMGuO46+PRT6NsXXnoJqlQJOirJps9ZESmSrCw/HLJVK1i0yPe9v/22EnysUZIXkUJbsQLOPx9uv933wX/zDdxwg4ZHxiIleREpsEOHfN97mzbw3//6OjSzZsH//E/QkUl+1CcvIgUyaxYMGADffQdXXgkjRqj2TDxQS15Ejuv7733VyG7d/OQekyfDxIlK8PFCSV5E8rRzJwwcCKef7itG/u1vsGyZT/gSP9RdIyJH+fln+Pvf4emnYfduuPVWX1isVq2gI5OiUJIXEcAn9BEj4Nln4aefoGdP+Otf/RBJiV9K8iIl3PbtMHKkT+67dvnumIce8tUjJf4pyYuUUN9+6ytFjhvnJ/Xo3dsn9/btg45MIklJXqQEyciADz7wyf3jj6F8eX8R0733+hOskniU5EVKgCVLfNmBN9+ELVugfn14/HE/32q1akFHJ9GkJC+SoLZsgUmTfHJfsABKl4YePfxcq7/5jb8viU+/ZpEEsno1vPeeX774ws/Q1LYtPPccXH011KwZdIRS3JTkReLYgQM+mc+YAR9+CIsX++1t2sAjj8Bll0HLloGGKAFTkheJIwcPwtdf+9rtM2bAZ5/5RJ+UBF26wFNPwaWXQpMmQUcqsUJJXiRGZWbCDz/4Wu3z5vnl6699JUjwLfQ774QLLoCuXeGkk4KNVwpp717YvBk2bfLrWrV8/eYIU5IXCdihQ7B2rU/oy5b58evffgtLl8L+/X6fChWgY0e47z7o3NkvKhAWY7Ky/KXCaWlHL1u3/nLZvNkn+dwuvzw+k7yZXQw8DyQBrzrnHo/2MUVigXOQng47dvj/602bcpYff4R163xi37DB54dsp5ziSwn87nd+3aaNX2s0TDHIyPDFe3bvzlnv3u2T908/+UuCs2/v3Ol/udnLTz8d/YvMrWpV31KvVcvPstKjB9SpA7Vr56zr1o3KW4rqn42ZJQEvAhcCG4GvzCzVObc0mseV4rN/f07jZMcO/z+Qe9mzxye69HS/b3q670M+fNj/Px0+7JfMTP96zh39+klJfjLopKScJTnZJ7zc6+RkKFPGL7nvH3s79/Oyl1KlchYzv3bO/7/mXh86lLMcPOjX6em+QbZ3L+zb59e7dvmfxc6dfr9jlSrlE3mDBnD22X7Cjezl1FOhRo2o/sriU2am/2Hmtezf7/+oDhzIuZ39R5d7yf4F7duXc3vPnqOX7K9Ox3PyyX6Ow+rV/UUG9evn3K5RI2dJSfHrmjX9H2BAot026ASscs6tBjCzCUBvQEk+DjgHGzf6YXlr1x69bNrkx2H//HP+zy9fHipVgooV/e0KFfxSsWJO8s1OuklJOVPHZa+zE2xm5tFL9odDRob/H9+7198/dChnnX07e8m+HylmjrJlHBXKZlKpfCaVymVQqVwGFctm0LTSITqfcpBqFQ5SreIBqlc4QM1K6dStvJc6J++lZsV9JJHp31z2sjMLtmfBvDw+XU60ZP+wjr1/7Pp4+xZkyR1Pdty5b2dm/nKd+5d27O2MjKM/6Y/95M/9qZpfC7mgypb1f3jZf5DZ6/r1/cmM3Evlyj6RV66cs1Sp4lvjlSvH3VeqaEdbF9iQ6/5G4H+jfEwpgs2bYeFCf2Xk0qW+b3jZsqO7Dc38N8qGDf3Y61NO8d8+s9cpKf5/IXspWzYKgWZlHd0C+/lnv85umeVuUh/TknP70snaf5CM/YfJSD9Exv7DHD6QiTtwkKxDGWQdPOyXQ4cpRRalyMJwR9ZlOHRkSXJZcBC/HOeDrliZHf1JeeynZl6P53U/ryX7a0727dxfe8yO/sqV/Xj2V6/SpY++nZzsP/Wztx/7iZ+9lC2b8/Useylb9pdL+fJQrlzOuly5nBZFhQp+e1JS8f4uYki0k3xe0/oe9YXczG4Hbgdo0KBBlMMR8Dlx7lyYP99fCblgge8jzlanDrRoATffDM2bw69+BY0a+UZPRBJ3Zqbvv9y+Pac/c+fOnL7O3Et2n+ju3Tn9PwWVnHzUP7uVL09ShQoklStH2ZTyUL5aTlLIThi5b+dOLsf2/+TVZ5SdtHIntdzJLXcSzP7qkjsp5pVEC7KIHEe0k/xGoH6u+/WATbl3cM6NAkYBdOjQ4ZgeWYmEPXtgzhw/u8/s2T6pZ/eBN2sG550HHTr4pVUr/4200Pbt818Htm71/TjZ623bckYZZN/+6adfdr5nM8v5apz9laBp06O/Op988i+/Yleq9Muv4snJRXgjIokl2kn+K6CpmTUGfgT6AtdE+ZgCrFoF//63Xz77zHd1JidDp07wwANw7rn+9gkTelaWT8wbN/phINnr3ENFNm3Ku3PezPfh1KzpT0C1bu1vp6T4E1W5l2rV/HLyyb4VKyIREdUk75zLMLO7gY/wQyhHO+eWRPOYJZVz/qKZd96B1FRYscJvb9ECfv97uPBCf0VkxYrHPDEry7e4V6/OOcO6bl3Osn59ztU32cqU8X06der4K3K6d88ZBpa7oz4lpUT3hYrEgqifJnbOfQh8GO3jlFQ//OAnfXj7bZ/Yk5N9K71/fz99W5Mm+L6ZdetgzipYudIvq1b5pL5mjR9yltspp/izq+3b+2vkGzTwHfL16vl1Sopa2yJxIr7GAgngB4uMHw+vvAJffum3nXsu/OGuA1x+2hKqbV7iM/79y2H5cp/Uc48frFjRn01t3hwuucR/EjRu7JeGDf3JRxFJCErycWTFCnjpJRgzxrF7t9Gy7k882XUOfUtNpP7qT+De9Tk7ly7tE/lpp/ni4aee6u83bepb6hqVIVIiKMnHOLdpM1Ne2cgzb6Yw84fGJHOIK5nE7xjJWT/OwbaX9Yn87LP9/G0tWvgWepMmGl0iIkryMcM5P1j9q69g4UIyF37DpLl1+dvu37GYjtRnPX+tOpybO31HrU4Noc190Gq0T+ZxdgWeiBQfZYeg7Nrlr0b68kuf2L/6CrZs4RDJvGn9eCL5BVYeasRpNXcy5pYVXPOH2iRXvz/oqEUkzijJFwfn/AnQOXN8UfC5c33NAOd833izZmRd0J2JSVczdPp5rN5UjjNawbt/gj59qlGqlGZaFpGiUZKPhkOH/KD1zz6Dzz/3yX3HDv9YtWq+GPjVV/t1x47MWliZQYN87ZjWreGDV+D//k/nRkUkfErykXDokK8VkF03YM4cP84R/GiWXr3gnHPgrLP8/VD2XrIE/tgXpk71w8/HjoVrr9X1QyISOUryRZGV5WdMnjHDL59/npPUW7eGW2/187Gdfba/+vMY+/bBn/8Mzzzjy6wMHw53363h6SISeUryBbVxo29yT58Os2b5CoqQU67x/PN9Yk9JOe7LTJ4M997rqwXcdBM8+eQJnyIiUmRK8vk5eNC30KdO9ct33/ntder4qbsuuAC6dfP3C2D9erjnHl9X5vTTfXf92WdHMX4REZTkj7Z1K3zwAfznPzBtmu9XKVPGt9BvvBEuvti33AtxRtQ539d+772+hMyTT8LAgbpOSUSKR8lO8s75s5/vv+9r8s6f77fXrw833OBb7Oefn0fpxoLZtg3uuMO//Dnn+GTfuHHEohcROaGSl+SzsvxY9ffe89l31Sq//X//F/7yF1+6sXXrsMcvpqbCbbf5a56eesq33jVqRkSKW8lI8pmZvhP8n/+Ef/3L109PToZf/xr++Ec/xLF27Ygcav9+GDDAV4hs2xZmzvQl10VEgpC4ST47sU+c6BP71q1+Qt8ePeDyy/26SPPc5W/1arjiCvj6az/70qOP+i59EZGgJFaSd85fNvr2236KpE2bfGK/5BK48kq/LmL/+omkpvpu/FKl/HnbSy6JymFERAolMZL8pk3wj3/4mTRWrvRdMT16+NIBPXtGLbGDnzv1wQfhiSfgjDNg0iRo1ChqhxMRKZTESPLbt8Njj/mRMIMHw2WXQdWqUT/sTz/57plZs/womuee01WrIhJbEiPJt2oFmzfnWUIgWn74wXfJrF4Nr7/uh9GLiMSaxEjyZsWa4D//HPr08acAZszw10qJiMSiUuE82cyGm9lyM/uvmb1nZlVyPTbEzFaZ2QozuyjsSGPEuHG+mkH16n64vRK8iMSysJI8MB1o6ZxrDXwPDAEwsxZAX+B04GJgpJnF9aVAzvnKkdddB2ee6ef9aNo06KhERI4vrCTvnJvmnMsI3Z0H1Avd7g1McM4ddM6tAVYBncI5VpCysvwVqw8/7IdJfvSRn/tDRCTWhduSz+1mYErodl1gQ67HNoa2/YKZ3W5mC8xsQVpaWgTDiYyMDLjlFhgxAu67D8aM0QVOIhI/TpjkzWyGmX2Xx9I71z5DgQxgXPamPF7K5fX6zrlRzrkOzrkONWrUKMp7iJqDB6FvX5/YH3kEnn5aU/KJSHw54ega59wFx3vczPoBPYFuzrnsRL4RqJ9rt3rApqIGGYT0dD/c/qOP4NlnfXeNiEi8CXd0zcXAYKCXcy4910OpQF8zK2tmjYGmwPxwjlWc9u3zE2lPnw6vvaYELyLxK9xx8i8AZYHp5vsx5jnn7nTOLTGzicBSfDfOXc65zDCPVSz27/dFKT//3JfAueqqoCMSESm6sJK8c+5Xx3lsGDAsnNcvbgcP+i6ajz+GN99UgheR+JcYV7xGwOHD8Nvf+ulcX30Vrr026IhERMIXySGUcSsjwyf11FR48UU/ZFJEJBGU+CSflQU33+wnjXr6aejfP+iIREQip8Qn+cGDff/7Y4/B738fdDQiIpFVopP8iBF+ku277oKhQ4OORkQk8kpskn/3XT/+vU8feP55XckqIompRCb5OXP8idbOnf1Y+KS4ro8pIpK/Epfkly+H3/wGGjb0o2nKlw86IhGR6ClRST4tzZcrSE6GKVMgJSXoiEREoqvEXAx1+LCfdHvLFvj0U2jSJOiIRESir8Qk+YEDfXIfNw46dgw6GhGR4lEiumtGjYKRI2HQILjmmqCjEREpPgmf5D/7zI+Dv/hi+Otfg45GRKR4JXSSX78eLr8cGjeG8eM1VFJESp6ETfL798Oll8KBAzB5MlSpEnREIiLFL2FPvA4cCIsW+bHwzZsHHY2ISDASsiU/frw/2TpokL/wSUSkpEq4JP/993D77XDmmfCXvwQdjYhIsBIqyR844Gd3KlMGJkzwV7aKiJRkCdUnf999sHgx/Oc/UL9+0NGIiAQvYVry77wDL78M998Pl1wSdDQiIrEhIknezP5oZs7MUnJtG2Jmq8xshZldFInj5GflSrjtNujSBYYNi+aRRETiS9jdNWZWH7gQWJ9rWwugL3A6UAeYYWanOucywz1eXkqX9gn+lVfUDy8iklskWvLPAoMAl2tbb2CCc+6gc24NsAroFIFj5alxY/joI2jQIFpHEBGJT2EleTPrBfzonFt8zEN1gQ257m8MbcvrNW43swVmtiAtLS2ccERE5Bgn7K4xsxnAKXk8NBT4E9A9r6flsc3lsQ3n3ChgFECHDh3y3EdERIrmhEneOXdBXtvNrBXQGFhsfhbsesAiM+uEb7nnHsRYD9gUdrQiIlIoRe6ucc5965yr6Zxr5JxrhE/s7Z1zW4BUoK+ZlTWzxkBTYH5EIhYRkQKLysVQzrklZjYRWApkAHdFa2SNiIjkL2JJPtSaz31/GKBR6yIiAUqYK15FROSXlORFRBKYORc7oxbNLA1YF8ZLpADbIxROkBLlfYDeSyxKlPcBei/ZGjrnauT1QEwl+XCZ2QLnXIeg4whXorwP0HuJRYnyPkDvpSDUXSMiksCU5EVEEliiJflRQQcQIYnyPkDvJRYlyvsAvZcTSqg+eREROVqiteRFRCSXhEryZvaYmf3XzL4xs2lmVifomIrKzIab2fLQ+3nPzKoEHVNRmdmVZrbEzLLMLO5GQpjZxaEZzlaZ2QNBx1NUZjbazLaZ2XdBxxIuM6tvZh+b2bLQ39aAoGMqCjMrZ2bzzWxx6H08GvFjJFJ3jZmd7Jz7OXT7XqCFc+7OgMMqEjPrDsxyzmWY2RMAzrnBAYdVJGbWHMgC/gH80Tm3IOCQCszMkoDv8bOfbQS+Aq52zi0NNLAiMLOuwF7gDedcy6DjCYeZ1QZqO+cWmdlJwEKgT7z9XsyX8K3onNtrZsnA58AA59y8SB0joVry2Qk+pCL51LCPB865ac65jNDdefhyzXHJObfMObci6DiKqBOwyjm32jl3CJiAn/ks7jjnPgV2Bh1HJDjnNjvnFoVu7wGWkc/ERLHMeXtDd5NDS0TzVkIleQAzG2ZmG4BrgYeCjidCbgamBB1ECVXgWc4kGGbWCGgHfBlwKEViZklm9g2wDZjunIvo+4i7JG9mM8zsuzyW3gDOuaHOufrAOODuYKM9vhO9l9A+Q/HlmscFF+mJFeS9xKkCz3Imxc/MKgHvAgOP+SYfN5xzmc65tvhv653MLKJdaVGpJx9N+c1UlYe3gQ+Ah6MYTlhO9F7MrB/QE+jmYvzkSSF+L/FGs5zFqFAf9rvAOOfcv4KOJ1zOuV1mNhu4GIjYyfG4a8kfj5k1zXW3F7A8qFjCZWYXA4OBXs659KDjKcG+ApqaWWMzKwP0xc98JgEKnbB8DVjmnHsm6HiKysxqZI+cM7PywAVEOG8l2uiad4Fm+JEc64A7nXM/BhtV0ZjZKqAssCO0aV4cjxS6FPg7UAPYBXzjnLso0KAKwcx6AM8BScDo0IQ4ccfMxgPn4asdbgUeds69FmhQRWRmZwOfAd/i/98B/uSc+zC4qArPzFoDY/F/W6WAic65P0f0GImU5EVE5GgJ1V0jIiJHU5IXEUlgSvIiIglMSV5EJIEpyYuIJDAleRGRBKYkLyKSwJTkRUQS2P8H6dIMFIfK5rgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# 画出更新一次之后的模型\n", + "y_pred = multi_linear(x_train)\n", + "\n", + "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", + "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "因为只更新了一次,所以两条曲线之间的差异仍然存在,我们进行 100 次迭代" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 20, Loss: 65.56586\n", + "epoch 40, Loss: 15.41177\n", + "epoch 60, Loss: 3.70702\n", + "epoch 80, Loss: 0.97122\n", + "epoch 100, Loss: 0.32874\n" + ] + } + ], + "source": [ + "# 进行 100 次参数更新\n", + "for e in range(100):\n", + " y_pred = multi_linear(x_train)\n", + " loss = get_loss(y_pred, y_train)\n", + " \n", + " w.grad.data.zero_()\n", + " b.grad.data.zero_()\n", + " loss.backward()\n", + " \n", + " # 更新参数\n", + " w.data = w.data - 0.001 * w.grad.data\n", + " b.data = b.data - 0.001 * b.grad.data\n", + " if (e + 1) % 20 == 0:\n", + " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以看到更新完成之后 loss 已经非常小了,我们画出更新之后的曲线对比" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqjElEQVR4nO3dd3gU5drH8e+dRghFSgLSAhEB6cUQwYIogg0RKYpYEEVsKOpRFDmiR+UIBysqKCICihRBAQu8KFUR1ICoQEB6LyG0UELK3u8fu2DUACG7m9nd3J/r2mt3Z2bnuSfoL0+enXlGVBVjjDGhKczpAowxxviPhbwxxoQwC3ljjAlhFvLGGBPCLOSNMSaEWcgbY0wIy3fIi8hoEdkjIityLSsnIt+IyFrPc9lc6/qLyDoRWSMiV/u6cGOMMWd2Nj35McA1f1v2NDBHVWsBczzvEZF6QDegvuczw0Uk3OtqjTHGnJWI/G6oqgtFpMbfFt8ItPa8HgvMB57yLJ+oqseBjSKyDkgCFp+ujdjYWK1R4+9NGGOMOZ2lS5fuVdW4vNblO+RPoaKq7gRQ1Z0iUsGzvAqwJNd22zzLTqtGjRokJyd7WZIxxhQtIrL5VOv89cWr5LEsz/kTRKS3iCSLSHJqaqqfyjHGmKLJ25DfLSKVADzPezzLtwHVcm1XFdiR1w5UdaSqJqpqYlxcnn9tGGOMKSBvQ34G0MPzugcwPdfybiJSTEQSgFrAT162ZYwx5izle0xeRCbg/pI1VkS2Ac8Bg4HJInIPsAXoCqCqK0VkMrAKyAYeUtWcghSYlZXFtm3byMjIKMjHjZeio6OpWrUqkZGRTpdijCkACaSphhMTE/XvX7xu3LiRUqVKUb58eUTyGuo3/qKqpKWlkZ6eTkJCgtPlGGNOQUSWqmpiXusC/orXjIwMC3iHiAjly5e3v6KMCWIBH/KABbyD7GdvTHALipA3xphQNmwYzJjhn31byOfDsGHDqFu3LrfddhszZsxg8ODBAEybNo1Vq1ad3G7MmDHs2PHnmaK9evX6y3pjjPm7Awegf3+YPv2MmxaIt1e8FgnDhw9n5syZJ7987NChA+AO+fbt21OvXj3AHfINGjSgcuXKAIwaNcqZgnPJyckhPNymDTImUI0ZA0ePwkP15gFX+Hz/1pM/g/vvv58NGzbQoUMHXn/9dcaMGUOfPn344YcfmDFjBk8++SRNmjRhyJAhJCcnc9ttt9GkSROOHTtG69atT07TULJkSQYMGEDjxo1p0aIFu3fvBmD9+vW0aNGC5s2bM3DgQEqWLJlnHePGjaNRo0Y0btyYO+64A4C77rqLKVOmnNzmxGfnz5/PFVdcQffu3WnYsCFPPfUUw4cPP7nd888/z6uvvgrA0KFDad68OY0aNeK5557z/Q/QGHNKLhe881YOF4ctplnKeL+0EVw9+UcfheXLfbvPJk3gjTdOufrdd99l1qxZzJs3j9jYWMaMGQPAxRdfTIcOHWjfvj1dunQBYObMmbzyyiskJv7zTKYjR47QokULBg0aRL9+/Xj//ff597//Td++fenbty+33nor7777bp41rFy5kkGDBrFo0SJiY2PZt2/fGQ/rp59+YsWKFSQkJPDLL7/w6KOP8uCDDwIwefJkZs2axezZs1m7di0//fQTqkqHDh1YuHAhrVq1OuP+jTHemz0b1m0I5wWGwSP9/dKG9eQLSVRUFO3btwfgwgsvZNOmTQAsXryYrl27AtC9e/c8Pzt37ly6dOlCbGwsAOXKlTtje0lJSSeHl5o2bcqePXvYsWMHv/76K2XLliU+Pp7Zs2cze/ZsmjZtSrNmzVi9ejVr16719lCNMfn09lsuKobtoXOrvdCokV/aCK6e/Gl63IEuMjLy5OmI4eHhZGdn5/uzqprnqYwRERG4XK6T22RmZp5cV6JEib9s26VLF6ZMmcKuXbvo1q3byc/079+f++6776yPxxjjnfXr4euZwrM6gqjHHvJbO9aT90KpUqVIT08/5fv8aNGiBVOnTgVg4sSJeW7Tpk0bJk+eTFpaGsDJ4ZoaNWqwdOlSAKZPn05WVtYp2+nWrRsTJ05kypQpJ4eXrr76akaPHs3hw4cB2L59O3v27DnlPowxvjNiBISTw31Vv4YbbvBbOxbyXujWrRtDhw6ladOmrF+/nrvuuov777//5Bev+fHGG2/w2muvkZSUxM6dOznnnHP+sU39+vUZMGAAl19+OY0bN+bxxx8H4N5772XBggUkJSXx448//qP3/vd9pKenU6VKFSpVqgRAu3bt6N69Oy1btqRhw4Z06dLlrH9JGWPO3tGj8MHIbDrpVCo/ejP48Qy4gJ+7JiUlhbp16zpUkf8dPXqU4sWLIyJMnDiRCRMmMN1fJ8wWUKj/GxhT2EaNgnvvhYXF2nLZzslQtuyZP3Qap5u7JrjG5EPQ0qVL6dOnD6pKmTJlGD16tNMlGWP8SBXefiObRrKKS3vW8jrgz8RC3mGXXXYZv/76q9NlGGMKyaJF8OvKCEbyFvLI435vz0LeGGMK0dvDXJSRQ3S/YjcUwjCoffFqjDGFZOtWmDIVeupoSjxeOKcuW8gbY0wheestUJfySI0v4NprC6VNG64xxphCkJ4OI0dk04Wp1PhXZwgrnD629eQLQY0aNdi7d6/TZRhjHPTBB3DwcAT/Kj0KevYstHYt5M+Cqp6cRsDqMMbkV3Y2vPlKJpfyHUmPXgynuXDR1yzkz2DTpk3UrVuXBx98kGbNmrF169ZTTs/bsWNHLrzwQurXr8/IkSPPuO9Zs2bRrFkzGjduTJs2bQD3NMCvvPLKyW0aNGjApk2b/lHHiy++SL9+/U5uN2bMGB5++GEAPv74Y5KSkmjSpAn33XcfOTk5vvpxGGMK4PPPYdP2KB6PfBv69CnUtn0yJi8ijwG9AAV+B3oCMcAkoAawCbhZVfd7044DMw0DsGbNGj788EOGDx9+2ul5R48eTbly5Th27BjNmzenc+fOlC9fPs99pqamcu+997Jw4UISEhLyNX1w7jpSU1Np2bIl//vf/wCYNGkSAwYMICUlhUmTJrFo0SIiIyN58MEHGT9+PHfeeedZ/mSMMb6gCq8OzqQmW+lwdyzExRVq+16HvIhUAR4B6qnqMRGZDHQD6gFzVHWwiDwNPA085W17TqhevTotWrQA+Mv0vACHDx9m7dq1tGrVimHDhvH5558DsHXrVtauXXvKkF+yZAmtWrU6OR1wfqYPzl1HXFwc5513HkuWLKFWrVqsWbOGSy65hHfeeYelS5fSvHlzAI4dO0aFChW8+wEYYwps8WL4cVkUb/M64U88Vujt++rsmgiguIhk4e7B7wD6A60968cC8/Ey5J2aaTj3xF+nmp53/vz5fPvttyxevJiYmBhat25NRkbGKfeZn+mDgb/s4+8TkN1yyy1MnjyZCy64gJtuugkRQVXp0aMHL7/88lkfpzHG914dkkVZOcxdHfbD+ecXevtej8mr6nbgFWALsBM4qKqzgYqqutOzzU4gz+6kiPQWkWQRSU5NTfW2HL871fS8Bw8epGzZssTExLB69WqWLFly2v20bNmSBQsWsHHjRuCv0wcvW7YMgGXLlp1cn5dOnToxbdo0JkyYwC233AK4pyWeMmXKySmD9+3bx+bNm707aGNMgaxfD59/EcEDOpwS/R9xpAZfDNeUBW4EEoADwKcicnt+P6+qI4GR4J6F0tt6/K1du3akpKTQsmVLwH1f1Y8//phrrrmGd999l0aNGlGnTp2TwyqnEhcXx8iRI+nUqRMul4sKFSrwzTff0LlzZ8aNG0eTJk1o3rw5tWvXPuU+ypYtS7169Vi1ahVJSUkA1KtXj5deeol27drhcrmIjIzknXfeoXr16r77IRhj8uXN13OI0BweumgpXDTAkRq8nmpYRLoC16jqPZ73dwItgDZAa1XdKSKVgPmqWud0+yqKUw0HA/s3MObspaVB9SpZdDk+njFfxsH11/utrdNNNeyLUyi3AC1EJEbcg8xtgBRgBtDDs00PILAmSTfGGD8a9qZy5HgkT573WaFNYZAXr4drVPVHEZkCLAOygV9wD7+UBCaLyD24fxF09bYtY4wJBunpMOy1LG7iS+o/26nQpjDIi0/OrlHV54Dn/rb4OO5evS/2n+eZKMb/AunOYcYEi3dHKAeORNG/0li4bYqjtQT8Fa/R0dGkpaVZ2DhAVUlLSyM6OtrpUowJGhkZ7ouf2jKb5s9fD5GRjtYT8LNQVq1alW3bthEMp1eGoujoaKpWrep0GcYEjQ9HK7v3F+OZ2Pehx8dOlxP4IR8ZGXnyqlBjjAlkWVkw5IXjtGQZlw+8HIoVc7qkwA95Y4wJFhMmwObd0bxdZgTS68yTFBYGC3ljjPEBlwteHniURqzl+gFNoHhxp0sCLOSNMcYnpk2D1ZtjmFDqbeSBN5wu5yQLeWOM8ZIq/PffRzifHXR9qmah3hTkTAL+FEpjjAl0X30FS1NK8FTxtwh/+EGny/kL68kbY4wXVGHgk0dJYBc9/hULpUs7XdJfWMgbY4wXpk+HX1bH8GHMq0Q+8V+ny/kHC3ljjCkglwuee+IItdjO7QOqwznnOF3SP1jIG2NMAU2dovy2vgQfl36DiL5DnS4nTxbyxhhTADk58Hy/I9RlC91eqBdQZ9TkZiFvjDEFMHmSsmpzSSaVf5vw+193upxTspA3xpizlJ0Nzz95hIZsoMvLFwbEHDWnYiFvjDFn6ZOPXfyxoySfVXqPsJ5vOl3OaVnIG2PMWcjKgheePkJT1tLx1csgIrBjNLCrM8aYADPqvRzW7y7FFzU+RG4J7F48WMgbY0y+pafD888c5zKSuf7Ndo7euzW/LOSNMSafXh2UwZ70GGY0nYDcMNzpcvIl8H8NGWNMANi1C155TejCp1z0fi8QcbqkfPFJyItIGRGZIiKrRSRFRFqKSDkR+UZE1nqey/qiLWOMccILT6ZzPCuM/3b4ES680Oly8s1XPfk3gVmqegHQGEgBngbmqGotYI7nvTHGBJ01a2Dk+BjuC/+AWm894nQ5Z8XrkBeR0kAr4AMAVc1U1QPAjcBYz2ZjgY7etmWMMU545oH9FNejDHwoDeLjnS7nrPiiJ38ekAp8KCK/iMgoESkBVFTVnQCe5wp5fVhEeotIsogkp6am+qAcY4zxncU/KJ/NK0u/mHeo8OLDTpdz1nwR8hFAM2CEqjYFjnAWQzOqOlJVE1U1MS4uzgflGGOMb6jCk732cS47eXxQ+YC7IUh++CLktwHbVPVHz/spuEN/t4hUAvA87/FBW8YYU2g+m5zNopTyPF/xXUr06el0OQXidcir6i5gq4jU8SxqA6wCZgA9PMt6ANO9bcsYYwrLsWPwrweO0JDfuOe9pICfvuBUfFX1w8B4EYkCNgA9cf8CmSwi9wBbgK4+assYY/zulefS2bz/HOZd9AkRHV52upwC80nIq+pyIDGPVW18sX9jjClMW7bAy69F0TVsCq0/uidoLnzKi13xaowxf9OvZyqa42LoAxuhVi2ny/GKhbwxxuSyYG42k+bG8XTpEVQf8qDT5XgtOL9JMMYYP8jOhkfuOEA8R3hyeELA3rf1bFhP3hhjPN5/LZ3fdsTySsNxxHTv6HQ5PmEhb4wxwL598O+BwuUsoMvELkH9ZWtuFvLGGAM8ddduDh6PZliPpUi9uk6X4zMW8saYIu+7uVmM+qIij5UcRaO37nW6HJ+yL16NMUXa8ePQ+5aDVOcwz39YHUqVcrokn7KevDGmSPvfk6ms3hvL8FaTKNHlWqfL8TkLeWNMkfXHaheD3i7NzVGfc92nwTkB2ZlYyBtjiiRVuP/GnUTrMd4YkgkV8rzlRdCzkDfGFEkfDdvHvD+qMLjOGCr1vdnpcvzGQt4YU+TsTVUe7xdBy7Al9P7ihpA5Jz4vFvLGmCLn4Y5bOJQZzXuPriasVk2ny/ErC3ljTJEyZeQ+Jv5QnYFVRtNwyO1Ol+N3FvLGmCJjz27lgT7hJMpSnv6/K4L2bk9nw0LeGFMkqML912wiPasYY/uvJqJ+nTN/KARYyBtjioRPXtvF58sTePH8cdR78Vanyyk0FvLGmJC3Y2sOfZ6K4eLwJTz+7XUQVnSir+gcqTGmSFKFe9tu5HhOBGOG7CG8elWnSypUPgt5EQkXkV9E5EvP+3Ii8o2IrPU8l/VVW8YYk1+jntvK12vOZ3DjidR6/Aanyyl0vuzJ9wVScr1/GpijqrWAOZ73xhhTaFYlH6HvS7FcFbWQPrM7hPRFT6fik5AXkarA9cCoXItvBMZ6Xo8FOvqiLWOMyY9jR5VubdMoqemM+ySCsAqxTpfkCF/15N8A+gGuXMsqqupOAM9zaM7+Y4wJSE/csJrfD8Qz7rbZVOp8sdPlOMbrkBeR9sAeVV1awM/3FpFkEUlOTU31thxjjOHzYVsZPrcu/6o2mWvGFp3TJfMiqurdDkReBu4AsoFooDTwGdAcaK2qO0WkEjBfVU979UFiYqImJyd7VY8xpmjbsuYYTepnUlM2sGjduURVr+R0SX4nIktVNTGvdV735FW1v6pWVdUaQDdgrqreDswAeng26wFM97YtY4w5nexsuL31VrJywpjw/pEiEfBn4s/z5AcDbUVkLdDW894YY/zm+S4r+G5XbUZ0mMn5d13qdDkBwaez86jqfGC+53Ua0MaX+zfGmFOZ9up6Bk1vwD3nfsntUzs5XU7AsCtejTFBb/WiNO58sgLNo5bz9o9JRWJ2yfyykDfGBLVDaVnc1O4w0XqMqZ+FER1vZ2vnZiFvjAlaLhfclbSKtUerMPmZX6l2fSOnSwo4FvLGmKA15OZkPt/QmKGtv6L1oLZOlxOQLOSNMUFp1hurGTC1Gd0qzuPR2dc7XU7AspA3xgSdFTO3cvPjVWgYtYZRPzVCIu2L1lOxkDfGBJWdK/dxXYdwSnGYL/8vihLx5Z0uKaBZyBtjgsaRtAxuaLGHfdml+eL9XVRrXdPpkgKehbwxJijkZLm4rdHv/HK4FhP7/UKze5o6XVJQsJA3xgSFJ1ouYvqO5rzZcT7th1zmdDlBw0LeGBPwhnVZyBtLL6Nvo3n0+exKp8sJKhbyxpiANq73d/Sd2oqO5y7h1Z9bFclb+HnDQt4YE7A+f2IRd7/fkjblljEhpQnhUeFOlxR0LOSNMQHpmxeX0O3VRJqXWs20VXWILhPtdElByULeGBNwfhiWTMeBDbmg+Ga+/j2ekhVLOF1S0LKQN8YElOVjlnNd3/OpHLWX2cviKFu9tNMlBTULeWNMwFg2Mpk2PeMpFXGMbxfFUPGCsk6XFPQs5I0xAeHHNxbT5r6alIzMYP7CcKonxjldUkiwkDfGOO77/y6k7WP1KRd1mIVLilGzpd34w1cs5I0xjpo7YA5XD7iQytH7WfhLaao3swnHfMlC3hjjmFl9Z3L9fy/mvBK7WbCiPFXqneN0SSHH65AXkWoiMk9EUkRkpYj09SwvJyLfiMhaz7N9g2KMcVNlXIcp3DDsKi4otYN5KZWoWLOk01WFJF/05LOBf6lqXaAF8JCI1AOeBuaoai1gjue9MaaI04zjvNRkCj2+6MLlldcyf0M8sdWKO11WyPI65FV1p6ou87xOB1KAKsCNwFjPZmOBjt62ZYwJbtmp++l93jc8+1tX7mjyG19vqMs5sZFOlxXSfDomLyI1gKbAj0BFVd0J7l8EgH1dbkwRdnjlZjok/M6one0ZcOMKxi5rRFQxm2zM33wW8iJSEpgKPKqqh87ic71FJFlEklNTU31VjjEmgGye8AOtmhxk9pGLee9ff/DStAY2mWQh8UnIi0gk7oAfr6qfeRbvFpFKnvWVgD15fVZVR6pqoqomxsXZxQ/GhBRV5tz/KRd2r80GVw1mvLuT3q/UdrqqIsUXZ9cI8AGQoqqv5Vo1A+jhed0DmO5tW8aY4KHph3mlyce0e68TFUsd4+fkMK67r5rTZRU5ET7YxyXAHcDvIrLcs+wZYDAwWUTuAbYAXX3QljEmCBxZvpZ7Ll/HpEN30KXhaj5cVIeSpWx8xgleh7yqfg+c6l+vjbf7N8YEEVVWvPwFtz5bk1WudgzutY5+Iy+w8XcH+aInb4wx6L79vNN2Gk8su5VzIo8yc2wa7W4/3+myijwLeWOM13Z/toi7b8vg64yeXFdnHaPnJlCxst2qLxDY3DXGmILLyGBm19E06nw+c45fyltPbubLlPMt4AOI9eSNMQWyd9r3PH7nXj5Kv5sGZbczZ1YODZKqO12W+RvryRtjzoqm7WP85SOpe1MdJqRfz4DuG/l5RxUaJMU4XZrJg4W8MSZ/VNk0bAbXVl7O7Qt7U7PyMZb9mM1L4xOIjna6OHMqNlxjjDmjIwuSGXrHb/xvazfCwuDNftt56L/xhNvQe8CzkDfGnJJr81Y+uvVrnlncnh0kcnPieoZOrkF8QhWnSzP5ZMM1xph/OnCAhXeOonlCKnctvo8qlZTvZx9l0s81iU+w7nswsZA3xvzpwAEW3f0B7eKWcflHvUgtHs/4N1NZsq0ql7S1L1aDkQ3XGGNg/36+e+wz/vNxTebk3EOFYgcY+vB2HnypCjGW7UHNQt6YIsy1fiOznviWV76ow7yce6hYbD+vPrqd+1+oQkxMGafLMz5gIW9MUaNK+qxFjHl6NW/91oq13Evl4vt4/aHt9P5PFWJiyjpdofEhC3ljior9+1n56ixGjcxhdOoNHOJSWlTewgvPpNG5d3ki7VarIclC3phQlpND2tT5TBi8mTHLG7NUbyWCLG5O2kzfocVIahXvdIXGzyzkjQk1OTkc+uZHZg5by6R5cXyZcRVZtKFp3Fbe6LGN7v2qEhdnUwAXFRbyxoSCzEz2TPuBGcO38fnic/k28zIyuZiKxfbzcPtN9HiuBo0S7dZ7RZGFvDHBSJXjK9byw8gVfDsrm283JJDsugwX4dSI2UOfqzZyU994WrYpS3i4fZFalFnIGxMMcnI49NNqfpq0kR8XHGPh6gp8l9GcY9QmnGxanLuJZy9fS8e+1WncogIiFZyu2AQIC3ljAo3Lxf5fNrHi6y38vugQy1ZEsWRnPKtcdVHqA1C/1BZ6t1zHVbfG0eqWSpQubWPsJm8W8sY4JPNwJpu/38r6JXtY/+sR1q9TUraVZMXBamzT84DzACgXfoAWVbdyS+JvtOhQgeYdKlGmbDxgZ8aYM/N7yIvINcCbQDgwSlUH+7tNE7iysyEjw/2clfXnc06Oe73qX7cPC4PwcPfjxOvISPcjIsL9HBYgMzCpuo/twH5l39YjpK0/wL4th0nbdozdW46zY7uLHXsi2XEwhu1Hy7IjpyIuagI1ASjOUWqV2EHr2jtp2GAHDVuVpeF11ahSswwiZRw9NhO8/BryIhIOvAO0BbYBP4vIDFVd5c92jX8dPw67dsHu3e7nE6/T0uDAgb8+0tPh6FH349gxJStLfF5PmCiRES6iPI/ICCUyXImKdBEZ7lkXqURGKBHh7nURuR5hooSJizBRBCUMRV2KK8eFK0fRHMWVo2RlwfFMyMwUMrOEzGzhaGYEhzOjOJwdzeGc4rgIBwQo6Xn8qSz7qByZSuWSh7ig+j7iK6+hZr1i1LywDDUvq8y5F5RBxIZdjG/5uyefBKxT1Q0AIjIRuBGwkA9w+/dDSor7sWEDbNr052PHjrw/U7pYBmUij1ImIp0yHKS67qNkzkFKZB8kJusQMTmHiOEoxThOJFlEkH3yOZwcBHc3/sSzIrgII4fwk885hJNFJNlEkEWk+6GRZGVFkpkVRRaRZBJFJlF/rvcsO/G5bCLI8DxnEXmyndyPE6/coe9+jjqxZ8miRFgOUeE5xERlU7JUNiWL51CyhIsSJYQy5cIoXzGCcpWjKVetBOUTShN3QXliEipCWLlC+fcz5gR/h3wVYGuu99uAi3JvICK9gd4A8fE2xljYcnJg9WpIToalS2HlSli1yt07PyE8zEV86YPUKLaTq3Uj1UumUPXwaiqyi3PZRUV2U4E9FMvMghJloXz5Px9lykCpUrkesVCiBBQvDsWKQXS0+xEV9ecYzIlHWBiI/POh+ue4jiq4XH8uc7n+fOS17iQFskAz3e38/XFiTOhEXZGREBPjrtuu/zdBxN8hn9ff5n8ZdVXVkcBIgMTERM1je+ND+/bBd9/BwoXw88+wbBkcOeJeVzI6iwZld3Bt+BrqnvMz9Q7+QF1SiHdtIeKgC6pVg/POg4QEqF4dKl8ElSv/+YiNxe4HZ0xg8XfIbwNyX2ZXFTjFH/vGHw4fhjlzYP589+PXX92d2uioHJrFbeWesktJzPk/mmcspHbGH4Tti4J69eCK+tCgFdR/AGrXdod6sWJOH44x5iz5O+R/BmqJSAKwHegGdPdzm0Xe5s3wxRfw5Zcwbx5kZrpD/ZKK63ihwje03j2R5pk/U2yPQuPG0CEJkvpD8+ZQp471xo0JIX4NeVXNFpE+wP/hPoVytKqu9GebRdWGDTBhAkyaBL//7l5WO3YfD8fN5vod73Nx5vcUSw2DSy6Bh6+FK/4HF15ovXNjQpzfz5NX1a+Br/3dTlG0Zw9MngyffAKLF7uXXVplA6/GTaJ96ofU3rvW3VPveQO0fQ4uushC3Zgixq54DTIuF8ydCyNGwPTp7rNjGlXcxeDYj7l17zDi9+yCK6+EG/rCDTeAnbFkTJFmIR8k9u+HMWPc4b52LZQvkcFjcVPosWswDfasgtatoftA6NwZytqsg8YYNwv5ALdpEwwdCh9+CMeOwcWVNjAw+mW6HPmI6PNqw7/ugm7doGpVp0s1xgQgC/kAtWoVDB4Mn3zivuz+ztiZPHzsGRqnrYauXeHBudCypfviIGOMOQUL+QCzfDn85z8wbRrERGXzSJnxPJ42gKrFImDIg9CzJ8TFOV2mMSZIWMgHiE2b4N//hvHjoUzxDAae8x4PH3yR2EqV4M0hcMst7kv9jTHmLFhqOGzvXhg0CIYPV8I0h6dKvcfT6QMo06Qe9P8Qrr8+cObSNcYEHQt5h2RlwbBh8MILyuHDcFepz/jPwb5UbVoFXv4MrrjCxtuNMV6zkHfAokXwwAPuK1OvLbOEoa5e1K+YDR+8CZ06WbgbY3zGxgEK0d69cM89cOmlcGBDGp/Tka+iO1P/vb7uOX47d7aAN8b4lIV8IVCFceOgTh1l3BgXT0YPY1VGTTo+UQv5Yw307m1fqhpj/MKSxc/27IH77nOfEnlJyd9413UbDZLKw/BFUL++0+UZY0Kc9eT9aPp0aNBA+fqLbIaG9WNBzLU0+Ohp98TuFvDGmEJgPXk/OHQI+vZ1zzXTpPgfzM3pTIPbmsCwFVDO7vFpjCk8FvI+9ssv0KWLsmmjMiBsCANLvE3Ux2+5z5oxxphCZiHvQx98AA89pMRqKgv1Ji656VwY8QtUqOB0acaYIsrG5H3g6FG4+27o1Qsucy1kWcRFXDL2PpgyxQLeGOMo68l7ad066NxZ+e034Vle5LlanxI+dSZccIHTpRljjIW8N+bNg043uQg7ks7XdOPanpXg7SUQE+N0acYYA9hwTYGNHg3t2rqofGQtyeEtuPbDW9wLLeCNMQHEQv4suVzw9NPu6Qmu0Ln8cG5nEn6aBHfd5XRpxhjzD16FvIgMFZHVIvKbiHwuImVyresvIutEZI2IXO11pQHg6FG4+WZlyBC4j3f5KukFzlk6Fxo1cro0Y4zJk7c9+W+ABqraCPgD6A8gIvWAbkB94BpguIiEe9mWo/btgytbu/hsqvIajzHi9h+InDfbzp4xxgQ0r0JeVWerarbn7RLgxN2kbwQmqupxVd0IrAOSvGnLSTt3wuWXZrM8OYvP6Mxjg89Fxo2F6GinSzPGmNPy5dk1dwOTPK+r4A79E7Z5lv2DiPQGegPEx8f7sBzf2LQJrmqdza4tmXwd0YkrJ/a2q1eNMUHjjCEvIt8C5+axaoCqTvdsMwDIBsaf+Fge22te+1fVkcBIgMTExDy3cUpKCrS9IoujqUeYU7wTF301EFq3drosY4zJtzOGvKpedbr1ItIDaA+0UdUTIb0NqJZrs6rAjoIW6YRly+DqNllEHNrHgnNupuG3r0OzZk6XZYwxZ8Xbs2uuAZ4COqjq0VyrZgDdRKSYiCQAtYCfvGmrMC1dCle2yqbEwR18d+7NNFzyvgW8MSYoeTsm/zZQDPhG3LetW6Kq96vqShGZDKzCPYzzkKrmeNlWofj1V2h3RSZlj+5gwfm9iJ/3CVTJ8+sEY4wJeF6FvKqef5p1g4BB3uy/sK1aBW1bZxKTvoe5dR4g/vuJEBvrdFnGGFNgNneNx9q10Oay44Qf2M+cWveT8N04C3hjTNCzkAc2boQrL8kge186C2r2pvb3oyEuzumyjDHGa0U+5HfuhCsvzuBI6lHmJfSm3qL37SpWY0zIKNIhn54O119xlNRdLuZVv4/GP4yAihWdLssYY3ymyIZ8VhZ0ue4Iv60pxhcVetH8hzfh3Lyu+TLGmOBVJENeFXrddozZ35fgg5J9ufa7Z6ByZafLMsYYnyuS88k/2y+DcZ8W5/nIQdw95zaoXdvpkowxxi+KXMi/93YWg16Jppd8wMAZiZAUtJNjGmPMGRWpkP/m/1w8+HA41/EVI8YUR64JiXuZGGPMKRWZkF+3Dm7pmEE9VjLxpfVE3Nnd6ZKMMcbvikTIHzoEN7Y+gGQcY3q3iZR65mGnSzLGmEIR8mfXuFxwx/X7WLO9NLMb9+O8cUNA8pru3hhjQk/I9+Sfe/QgM74vx+uxg7hyzgCIjHS6JGOMKTQhHfKffnSMl946h7sjP6LPgpuhfHmnSzLGmEIVsiG/8ncXd/UUWvIDw6dUQOrVdbokY4wpdCEZ8keOQNc2aZTKOcDU51dQrIOdKmmMKZpCMuQf6rSD1anlGX/FB1QaeK/T5RhjjGNCLuTHvL6fsbMrMzB2BG1m9LUzaYwxRVpIhfzK5Vk8+ERxrghbwLPzroSSJZ0uyRhjHBUyIX9iHL606wCfvLOf8Ab2Rasxxvgk5EXkCRFREYnNtay/iKwTkTUi4vdvPh9qv5nV+yow/sZPOff+jv5uzhhjgoLXV7yKSDWgLbAl17J6QDegPlAZ+FZEaqtqjrft5eWbMdsZO786z1V+nzaTevujCWOMCUq+6Mm/DvQDNNeyG4GJqnpcVTcC6wC/zel7VatMJjYcxLPfXw3FivmrGWOMCTpe9eRFpAOwXVV/lb+exVIFWJLr/TbPMr+Q8xK45bcB/tq9McYErTOGvIh8C+R189MBwDNAu7w+lscyzWMZItIb6A0QHx9/pnKMMcachTOGvKpelddyEWkIJAAnevFVgWUikoS7514t1+ZVgR2n2P9IYCRAYmJinr8IjDHGFEyBx+RV9XdVraCqNVS1Bu5gb6aqu4AZQDcRKSYiCUAt4CefVGyMMSbf/DKfvKquFJHJwCogG3jIX2fWGGOMOTWfhbynN5/7/SBgkK/2b4wx5uyFzBWvxhhj/slC3hhjQpiFvDHGhDBRDZyzFkUkFdjsxS5igb0+KsdJoXIcYMcSiELlOMCO5YTqqhqX14qACnlviUiyqiY6XYe3QuU4wI4lEIXKcYAdS37YcI0xxoQwC3ljjAlhoRbyI50uwEdC5TjAjiUQhcpxgB3LGYXUmLwxxpi/CrWevDHGmFxCKuRF5EUR+U1ElovIbBGp7HRNBSUiQ0Vkted4PheRMk7XVFAi0lVEVoqIS0SC7kwIEbnGcxvLdSLytNP1FJSIjBaRPSKywulavCUi1URknoikeP7b6ut0TQUhItEi8pOI/Oo5jv/4vI1QGq4RkdKqesjz+hGgnqre73BZBSIi7YC5qpotIkMAVPUph8sqEBGpC7iA94AnVDXZ4ZLyTUTCgT9w3+JyG/AzcKuqrnK0sAIQkVbAYWCcqjZwuh5viEgloJKqLhORUsBSoGOw/buIe572Eqp6WEQige+Bvqq65AwfzbeQ6smfCHiPEpziRiXBQFVnq2q25+0S3HPyByVVTVHVNU7XUUBJwDpV3aCqmcBE3Le3DDqquhDY53QdvqCqO1V1med1OpCCH+8+5y/qdtjzNtLz8GluhVTIA4jIIBHZCtwGDHS6Hh+5G5jpdBFFVBVga673fr2VpTl7IlIDaAr86HApBSIi4SKyHNgDfKOqPj2OoAt5EflWRFbk8bgRQFUHqGo1YDzQx9lqT+9Mx+LZZgDuOfnHO1fpmeXnWIJUvm9laQqfiJQEpgKP/u0v+aChqjmq2gT3X+tJIuLToTS/3DTEn051O8I8fAJ8BTznx3K8cqZjEZEeQHugjQb4lydn8e8SbPJ9K0tTuDxj2FOB8ar6mdP1eEtVD4jIfOAawGdfjgddT/50RKRWrrcdgNVO1eItEbkGeArooKpHna6nCPsZqCUiCSISBXTDfXtL4yDPF5YfACmq+prT9RSUiMSdOHNORIoDV+Hj3Aq1s2umAnVwn8mxGbhfVbc7W1XBiMg6oBiQ5lm0JIjPFLoJeAuIAw4Ay1X1akeLOgsich3wBhAOjPbc9SzoiMgEoDXu2Q53A8+p6geOFlVAInIp8B3wO+7/3wGeUdWvnavq7IlII2As7v+2woDJqvqCT9sIpZA3xhjzVyE1XGOMMeavLOSNMSaEWcgbY0wIs5A3xpgQZiFvjDEhzELeGGNCmIW8McaEMAt5Y4wJYf8PReMFzmKrczQAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# 画出更新之后的结果\n", + "y_pred = multi_linear(x_train)\n", + "\n", + "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", + "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以看到,经过 100 次更新之后,可以看到拟合的线和真实的线已经完全重合了" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "## 5. 练习题\n", + "\n", + "上面的例子是一个三次的多项式,尝试使用二次的多项式去拟合它,看看最后能做到多好\n", + "\n", + "**提示:参数 `w = torch.randn(2, 1)`,同时重新构建 x 数据集**" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/6_pytorch/1_NN/2-logistic-regression.ipynb b/6_pytorch/4-logistic-regression.ipynb similarity index 99% rename from 6_pytorch/1_NN/2-logistic-regression.ipynb rename to 6_pytorch/4-logistic-regression.ipynb index 9110970..eaf2864 100644 --- a/6_pytorch/1_NN/2-logistic-regression.ipynb +++ b/6_pytorch/4-logistic-regression.ipynb @@ -239,7 +239,9 @@ { "cell_type": "code", "execution_count": 14, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "np_data = np.array(data, dtype='float32') # 转换成 numpy array\n", @@ -260,7 +262,9 @@ { "cell_type": "code", "execution_count": 15, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 定义 logistic 回归模型\n", @@ -335,7 +339,9 @@ { "cell_type": "code", "execution_count": 17, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 计算loss, 使用clamp的目的是防止数据过小而对结果产生较大影响。\n", @@ -476,7 +482,9 @@ { "cell_type": "code", "execution_count": 31, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 使用 torch.optim 更新参数\n", @@ -616,7 +624,9 @@ { "cell_type": "code", "execution_count": 35, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 使用自带的loss\n", @@ -704,7 +714,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -718,7 +728,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.9.7" } }, "nbformat": 4, diff --git a/6_pytorch/4_GAN/mnist/MNIST/processed/test.pt b/6_pytorch/4_GAN/mnist/MNIST/processed/test.pt deleted file mode 100644 index b1a1249..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/processed/test.pt and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/processed/training.pt b/6_pytorch/4_GAN/mnist/MNIST/processed/training.pt deleted file mode 100644 index d88a4a8..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/processed/training.pt and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-images-idx3-ubyte b/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-images-idx3-ubyte deleted file mode 100644 index 1170b2c..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-images-idx3-ubyte and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz b/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz deleted file mode 100644 index 5ace8ea..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-labels-idx1-ubyte b/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-labels-idx1-ubyte deleted file mode 100644 index d1c3a97..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-labels-idx1-ubyte and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz b/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz deleted file mode 100644 index a7e1415..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/train-images-idx3-ubyte b/6_pytorch/4_GAN/mnist/MNIST/raw/train-images-idx3-ubyte deleted file mode 100644 index bbce276..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/train-images-idx3-ubyte and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/train-images-idx3-ubyte.gz b/6_pytorch/4_GAN/mnist/MNIST/raw/train-images-idx3-ubyte.gz deleted file mode 100644 index b50e4b6..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/train-images-idx3-ubyte.gz and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/train-labels-idx1-ubyte b/6_pytorch/4_GAN/mnist/MNIST/raw/train-labels-idx1-ubyte deleted file mode 100644 index d6b4c5d..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/train-labels-idx1-ubyte and /dev/null differ diff --git a/6_pytorch/4_GAN/mnist/MNIST/raw/train-labels-idx1-ubyte.gz b/6_pytorch/4_GAN/mnist/MNIST/raw/train-labels-idx1-ubyte.gz deleted file mode 100644 index 707a576..0000000 Binary files a/6_pytorch/4_GAN/mnist/MNIST/raw/train-labels-idx1-ubyte.gz and /dev/null differ diff --git a/6_pytorch/1_NN/4-deep-nn.ipynb b/6_pytorch/5-deep-nn.ipynb similarity index 99% rename from 6_pytorch/1_NN/4-deep-nn.ipynb rename to 6_pytorch/5-deep-nn.ipynb index ddc5828..8488a8c 100644 --- a/6_pytorch/1_NN/4-deep-nn.ipynb +++ b/6_pytorch/5-deep-nn.ipynb @@ -66,7 +66,9 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "import numpy as np\n", @@ -80,7 +82,9 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 使用内置函数下载 mnist 数据集\n", @@ -98,7 +102,9 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "a_data, a_label = train_set[0]" @@ -261,7 +267,9 @@ { "cell_type": "code", "execution_count": 9, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def data_tf(x):\n", @@ -298,7 +306,9 @@ { "cell_type": "code", "execution_count": 11, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", @@ -317,7 +327,9 @@ { "cell_type": "code", "execution_count": 12, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "a, a_label = next(iter(train_data))" @@ -346,7 +358,9 @@ { "cell_type": "code", "execution_count": 14, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 使用 Sequential 定义 4 层神经网络\n", @@ -399,7 +413,9 @@ { "cell_type": "code", "execution_count": 16, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 定义 loss 函数\n", @@ -495,7 +511,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", @@ -653,7 +671,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -667,7 +685,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.5.4" } }, "nbformat": 4, diff --git a/6_pytorch/1_NN/3-nn-sequential-module.ipynb b/6_pytorch/5-nn-sequential-module.ipynb similarity index 99% rename from 6_pytorch/1_NN/3-nn-sequential-module.ipynb rename to 6_pytorch/5-nn-sequential-module.ipynb index ef29018..d37a416 100644 --- a/6_pytorch/1_NN/3-nn-sequential-module.ipynb +++ b/6_pytorch/5-nn-sequential-module.ipynb @@ -41,7 +41,9 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "import torch\n", @@ -87,7 +89,9 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def plot_decision_boundary(model, x, y):\n", @@ -117,7 +121,9 @@ { "cell_type": "code", "execution_count": 5, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 变量\n", @@ -171,7 +177,9 @@ { "cell_type": "code", "execution_count": 7, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def plot_logistic(x):\n", @@ -234,7 +242,9 @@ { "cell_type": "code", "execution_count": 9, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 定义两层神经网络的参数\n", @@ -293,7 +303,9 @@ { "cell_type": "code", "execution_count": 11, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def plot_network(x):\n", @@ -388,7 +400,9 @@ { "cell_type": "code", "execution_count": 13, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# Sequential\n", @@ -448,7 +462,9 @@ { "cell_type": "code", "execution_count": 16, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 通过 parameters 可以取得模型的参数\n", @@ -502,7 +518,9 @@ { "cell_type": "code", "execution_count": 18, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def plot_seq(x):\n", @@ -556,7 +574,9 @@ { "cell_type": "code", "execution_count": 19, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 将参数和模型保存在一起\n", @@ -573,7 +593,9 @@ { "cell_type": "code", "execution_count": 20, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 读取保存的模型\n", @@ -637,7 +659,9 @@ { "cell_type": "code", "execution_count": 23, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 保存模型参数\n", @@ -769,7 +793,9 @@ { "cell_type": "code", "execution_count": 27, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "class module_net(nn.Module):\n", @@ -791,7 +817,9 @@ { "cell_type": "code", "execution_count": 28, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "mo_net = module_net(2, 4, 1)" @@ -843,7 +871,9 @@ { "cell_type": "code", "execution_count": 31, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 定义优化器\n", @@ -887,7 +917,9 @@ { "cell_type": "code", "execution_count": 33, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 保存模型\n", @@ -920,7 +952,9 @@ { "cell_type": "code", "execution_count": 34, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "net = nn.Sequential(\n", @@ -1021,7 +1055,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1035,7 +1069,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.5.4" } }, "nbformat": 4, diff --git a/6_pytorch/1_NN/5-param_initialize.ipynb b/6_pytorch/6-param_initialize.ipynb similarity index 96% rename from 6_pytorch/1_NN/5-param_initialize.ipynb rename to 6_pytorch/6-param_initialize.ipynb index b85c461..5415b7c 100644 --- a/6_pytorch/1_NN/5-param_initialize.ipynb +++ b/6_pytorch/6-param_initialize.ipynb @@ -26,7 +26,9 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "import numpy as np\n", @@ -37,7 +39,9 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 定义一个 Sequential 模型\n", @@ -53,7 +57,9 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 访问第一层的参数\n", @@ -96,7 +102,9 @@ { "cell_type": "code", "execution_count": 5, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# 定义一个 Tensor 直接对其进行替换\n", @@ -138,7 +146,9 @@ { "cell_type": "code", "execution_count": 7, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "for layer in net1:\n", @@ -173,7 +183,9 @@ { "cell_type": "code", "execution_count": 8, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "class sim_net(nn.Module):\n", @@ -206,7 +218,9 @@ { "cell_type": "code", "execution_count": 9, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "net2 = sim_net()" @@ -304,7 +318,9 @@ { "cell_type": "code", "execution_count": 12, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "for layer in net2.modules():\n", @@ -331,7 +347,9 @@ { "cell_type": "code", "execution_count": 13, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "from torch.nn import init" @@ -436,7 +454,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -450,7 +468,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.5.4" } }, "nbformat": 4, diff --git a/6_pytorch/backup/PyTorch_quick_intro.ipynb b/6_pytorch/backup/PyTorch_quick_intro.ipynb deleted file mode 100644 index a2df72f..0000000 --- a/6_pytorch/backup/PyTorch_quick_intro.ipynb +++ /dev/null @@ -1,1471 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PyTorch快速入门\n", - "\n", - "PyTorch的简洁设计使得它入门很简单,在深入介绍PyTorch之前,本节将先介绍一些PyTorch的基础知识,使得读者能够对PyTorch有一个大致的了解,并能够用PyTorch搭建一个简单的神经网络。部分内容读者可能暂时不太理解,可先不予以深究,后续的课程将会对此进行深入讲解。\n", - "\n", - "本节内容参考了PyTorch官方教程[^1]并做了相应的增删修改,使得内容更贴合新版本的PyTorch接口,同时也更适合新手快速入门。另外本书需要读者先掌握基础的Numpy使用,其他相关知识推荐读者参考CS231n的教程[^2]。\n", - "\n", - "[^1]: http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html\n", - "[^2]: http://cs231n.github.io/python-numpy-tutorial/" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Tensor\n", - "\n", - "Tensor是PyTorch中重要的数据结构,可认为是一个高维数组。它可以是一个数(标量)、一维数组(向量)、二维数组(矩阵)以及更高维的数组。Tensor和Numpy的ndarrays类似,但Tensor可以使用GPU进行加速。Tensor的使用和Numpy及Matlab的接口十分相似,下面通过几个例子来看看Tensor的基本使用。" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import print_function\n", - "import torch as t" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]])" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 构建 5x3 矩阵,只是分配了空间,未初始化\n", - "x = t.Tensor(5, 3) \n", - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0.3807, 0.4897, 0.0356],\n", - " [0.6701, 0.0606, 0.1818],\n", - " [0.8798, 0.7115, 0.8265],\n", - " [0.4094, 0.2264, 0.2041],\n", - " [0.9088, 0.9256, 0.3438]])" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 使用[0,1]均匀分布随机初始化二维数组\n", - "x = t.rand(5, 3) \n", - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([5, 3])\n" - ] - }, - { - "data": { - "text/plain": [ - "(3, 3)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(x.size()) # 查看x的形状\n", - "x.size()[1], x.size(1) # 查看列的个数, 两种写法等价" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`torch.Size` 是tuple对象的子类,因此它支持tuple的所有操作,如x.size()[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.1361, 1.4054, 0.9468],\n", - " [1.6410, 0.5193, 0.3720],\n", - " [0.9482, 1.6716, 1.4168],\n", - " [1.3925, 0.9253, 0.2908],\n", - " [1.4907, 1.7178, 0.7246]])" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y = t.rand(5, 3)\n", - "# 加法的第一种写法\n", - "x + y" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.1361, 1.4054, 0.9468],\n", - " [1.6410, 0.5193, 0.3720],\n", - " [0.9482, 1.6716, 1.4168],\n", - " [1.3925, 0.9253, 0.2908],\n", - " [1.4907, 1.7178, 0.7246]])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 加法的第二种写法\n", - "t.add(x, y)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.1361, 1.4054, 0.9468],\n", - " [1.6410, 0.5193, 0.3720],\n", - " [0.9482, 1.6716, 1.4168],\n", - " [1.3925, 0.9253, 0.2908],\n", - " [1.4907, 1.7178, 0.7246]])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 加法的第三种写法:指定加法结果的输出目标为result\n", - "result = t.Tensor(5, 3) # 预先分配空间\n", - "t.add(x, y, out=result) # 输入到result\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "最初y\n", - "tensor([[0.7554, 0.9157, 0.9113],\n", - " [0.9709, 0.4587, 0.1902],\n", - " [0.0684, 0.9601, 0.5903],\n", - " [0.9831, 0.6989, 0.0867],\n", - " [0.5819, 0.7923, 0.3808]])\n", - "第一种加法,y的结果\n", - "tensor([[0.7554, 0.9157, 0.9113],\n", - " [0.9709, 0.4587, 0.1902],\n", - " [0.0684, 0.9601, 0.5903],\n", - " [0.9831, 0.6989, 0.0867],\n", - " [0.5819, 0.7923, 0.3808]])\n", - "第二种加法,y的结果\n", - "tensor([[1.1361, 1.4054, 0.9468],\n", - " [1.6410, 0.5193, 0.3720],\n", - " [0.9482, 1.6716, 1.4168],\n", - " [1.3925, 0.9253, 0.2908],\n", - " [1.4907, 1.7178, 0.7246]])\n" - ] - } - ], - "source": [ - "print('最初y')\n", - "print(y)\n", - "\n", - "print('第一种加法,y的结果')\n", - "y.add(x) # 普通加法,不改变y的内容\n", - "print(y)\n", - "\n", - "print('第二种加法,y的结果')\n", - "y.add_(x) # inplace 加法,y变了\n", - "print(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "注意,函数名后面带下划线**`_`** 的函数会修改Tensor本身。例如,`x.add_(y)`和`x.t_()`会改变 `x`,但`x.add(y)`和`x.t()`返回一个新的Tensor, 而`x`不变。" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([0.4897, 0.0606, 0.7115, 0.2264, 0.9256])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Tensor的选取操作与Numpy类似\n", - "x[:, 1]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Tensor还支持很多操作,包括数学运算、线性代数、选择、切片等等,其接口设计与Numpy极为相似。更详细的使用方法,会在第三章系统讲解。\n", - "\n", - "Tensor和Numpy的数组之间的互操作非常容易且快速。对于Tensor不支持的操作,可以先转为Numpy数组处理,之后再转回Tensor。" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([1., 1., 1., 1., 1.])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = t.ones(5) # 新建一个全1的Tensor\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([1., 1., 1., 1., 1.], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "b = a.numpy() # Tensor -> Numpy\n", - "b" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1. 1. 1. 1. 1.]\n", - "tensor([1., 1., 1., 1., 1.], dtype=torch.float64)\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "a = np.ones(5)\n", - "b = t.from_numpy(a) # Numpy->Tensor\n", - "print(a)\n", - "print(b) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Tensor和numpy对象共享内存,所以他们之间的转换很快,而且几乎不会消耗什么资源。但这也意味着,如果其中一个变了,另外一个也会随之改变。" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2. 2. 2. 2. 2.]\n", - "tensor([2., 2., 2., 2., 2.], dtype=torch.float64)\n" - ] - } - ], - "source": [ - "b.add_(1) # 以`_`结尾的函数会修改自身\n", - "print(a)\n", - "print(b) # Tensor和Numpy共享内存" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Tensor可通过`.cuda` 方法转为GPU的Tensor,从而享受GPU带来的加速运算。" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[1.5168, 1.8951, 0.9824],\n", - " [2.3111, 0.5800, 0.5538],\n", - " [1.8280, 2.3831, 2.2433],\n", - " [1.8020, 1.1518, 0.4949],\n", - " [2.3995, 2.6434, 1.0684]], device='cuda:0')\n" - ] - } - ], - "source": [ - "# 在不支持CUDA的机器下,下一步不会运行\n", - "if t.cuda.is_available():\n", - " x = x.cuda()\n", - " y = y.cuda()\n", - " x + y\n", - "print(x+y)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "此处可能发现GPU运算的速度并未提升太多,这是因为x和y太小且运算也较为简单,而且将数据从内存转移到显存还需要花费额外的开销。GPU的优势需在大规模数据和复杂运算下才能体现出来。\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Autograd: 自动微分\n", - "\n", - "深度学习的算法本质上是通过反向传播求导数,而PyTorch的**`Autograd`**模块则实现了此功能。在Tensor上的所有操作,Autograd都能为它们自动提供微分,避免了手动计算导数的复杂过程。\n", - " \n", - "`autograd.Variable`是Autograd中的核心类,它简单封装了Tensor,并支持几乎所有Tensor有的操作。Tensor在被封装为Variable之后,可以调用它的`.backward`实现反向传播,自动计算所有梯度。Variable的数据结构如图2-6所示。\n", - "\n", - "\n", - "![图2-6:Variable的数据结构](imgs/autograd_Variable.svg)\n", - "\n", - "\n", - "Variable主要包含三个属性。\n", - "- `data`:保存Variable所包含的Tensor\n", - "- `grad`:保存`data`对应的梯度,`grad`也是个Variable,而不是Tensor,它和`data`的形状一样。\n", - "- `grad_fn`:指向一个`Function`对象,这个`Function`用来反向传播计算输入的梯度,具体细节会在下一章讲解。" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.autograd import Variable" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 1.],\n", - " [1., 1.]], requires_grad=True)" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 使用Tensor新建一个Variable\n", - "x = Variable(t.ones(2, 2), requires_grad = True)\n", - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(4., grad_fn=)" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y = x.sum()\n", - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.grad_fn" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "y.backward() # 反向传播,计算梯度" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 1.],\n", - " [1., 1.]])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# y = x.sum() = (x[0][0] + x[0][1] + x[1][0] + x[1][1])\n", - "# 每个值的梯度都为1\n", - "x.grad " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "注意:`grad`在反向传播过程中是累加的(accumulated),**这意味着每一次运行反向传播,梯度都会累加之前的梯度,所以反向传播之前需把梯度清零。**" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 2.],\n", - " [2., 2.]])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.backward()\n", - "x.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[3., 3.],\n", - " [3., 3.]])" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.backward()\n", - "x.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0., 0.],\n", - " [0., 0.]])" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 以下划线结束的函数是inplace操作,就像add_\n", - "x.grad.data.zero_()" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 1.],\n", - " [1., 1.]])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y.backward()\n", - "x.grad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Variable和Tensor具有近乎一致的接口,在实际使用中可以无缝切换。" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n", - " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n", - " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n", - " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403]])\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([[0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n", - " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n", - " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n", - " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403]])" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = Variable(t.ones(4,5))\n", - "y = t.cos(x)\n", - "x_tensor_cos = t.cos(x.data)\n", - "print(y)\n", - "x_tensor_cos" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 神经网络 (FIXME)\n", - "\n", - "Autograd实现了反向传播功能,但是直接用来写深度学习的代码在很多情况下还是稍显复杂,torch.nn是专门为神经网络设计的模块化接口。nn构建于 Autograd之上,可用来定义和运行神经网络。nn.Module是nn中最重要的类,可把它看成是一个网络的封装,包含网络各层定义以及forward方法,调用forward(input)方法,可返回前向传播的结果。下面就以最早的卷积神经网络:LeNet为例,来看看如何用`nn.Module`实现。LeNet的网络结构如图2-7所示。\n", - "\n", - "![图2-7:LeNet网络结构](imgs/nn_lenet.png)\n", - "\n", - "这是一个基础的前向传播(feed-forward)网络: 接收输入,经过层层传递运算,得到输出。\n", - "\n", - "### 3.1 定义网络\n", - "\n", - "定义网络时,需要继承`nn.Module`,并实现它的forward方法,把网络中具有可学习参数的层放在构造函数`__init__`中。如果某一层(如ReLU)不具有可学习的参数,则既可以放在构造函数中,也可以不放,但建议不放在其中,而在forward中使用`nn.functional`代替。" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Net(\n", - " (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n", - " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", - " (fc1): Linear(in_features=400, out_features=120, bias=True)\n", - " (fc2): Linear(in_features=120, out_features=84, bias=True)\n", - " (fc3): Linear(in_features=84, out_features=10, bias=True)\n", - ")\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " # nn.Module子类的函数必须在构造函数中执行父类的构造函数\n", - " # 下式等价于nn.Module.__init__(self)\n", - " super(Net, self).__init__()\n", - " \n", - " # 卷积层 '1'表示输入图片为单通道, '6'表示输出通道数,'5'表示卷积核为5*5\n", - " self.conv1 = nn.Conv2d(1, 6, 5) \n", - " # 卷积层\n", - " self.conv2 = nn.Conv2d(6, 16, 5) \n", - " # 仿射层/全连接层,y = Wx + b\n", - " self.fc1 = nn.Linear(16*5*5, 120) \n", - " self.fc2 = nn.Linear(120, 84)\n", - " self.fc3 = nn.Linear(84, 10)\n", - "\n", - " def forward(self, x): \n", - " # 卷积 -> 激活 -> 池化 \n", - " x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n", - " x = F.max_pool2d(F.relu(self.conv2(x)), 2) \n", - " # reshape,‘-1’表示自适应\n", - " x = x.view(x.size()[0], -1) \n", - " x = F.relu(self.fc1(x))\n", - " x = F.relu(self.fc2(x))\n", - " x = self.fc3(x) \n", - " return x\n", - "\n", - "net = Net()\n", - "print(net)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "只要在nn.Module的子类中定义了forward函数,backward函数就会自动被实现(利用`Autograd`)。在`forward` 函数中可使用任何Variable支持的函数,还可以使用if、for循环、print、log等Python语法,写法和标准的Python写法一致。\n", - "\n", - "网络的可学习参数通过`net.parameters()`返回,`net.named_parameters`可同时返回可学习的参数及名称。" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10\n" - ] - } - ], - "source": [ - "params = list(net.parameters())\n", - "print(len(params))" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "conv1.weight : torch.Size([6, 1, 5, 5])\n", - "conv1.bias : torch.Size([6])\n", - "conv2.weight : torch.Size([16, 6, 5, 5])\n", - "conv2.bias : torch.Size([16])\n", - "fc1.weight : torch.Size([120, 400])\n", - "fc1.bias : torch.Size([120])\n", - "fc2.weight : torch.Size([84, 120])\n", - "fc2.bias : torch.Size([84])\n", - "fc3.weight : torch.Size([10, 84])\n", - "fc3.bias : torch.Size([10])\n" - ] - } - ], - "source": [ - "for name,parameters in net.named_parameters():\n", - " print(name,':',parameters.size())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "forward函数的输入和输出都是Variable,只有Variable才具有自动求导功能,而Tensor是没有的,所以在输入时,需把Tensor封装成Variable。" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 10])" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "input = Variable(t.randn(1, 1, 32, 32))\n", - "out = net(input)\n", - "out.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "net.zero_grad() # 所有参数的梯度清零\n", - "out.backward(Variable(t.ones(1,10))) # 反向传播" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "需要注意的是,torch.nn只支持mini-batches,不支持一次只输入一个样本,即一次必须是一个batch。但如果只想输入一个样本,则用 `input.unsqueeze(0)`将batch_size设为1。例如 `nn.Conv2d` 输入必须是4维的,形如$nSamples \\times nChannels \\times Height \\times Width$。可将nSample设为1,即$1 \\times nChannels \\times Height \\times Width$。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.2 损失函数\n", - "\n", - "nn实现了神经网络中大多数的损失函数,例如nn.MSELoss用来计算均方误差,nn.CrossEntropyLoss用来计算交叉熵损失。" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(28.6268, grad_fn=)" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "output = net(input)\n", - "target = Variable(t.arange(0,10).float().unsqueeze(0)) \n", - "criterion = nn.MSELoss()\n", - "loss = criterion(output, target)\n", - "loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "如果对loss进行反向传播溯源(使用`gradfn`属性),可看到它的计算图如下:\n", - "\n", - "```\n", - "input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d \n", - " -> view -> linear -> relu -> linear -> relu -> linear \n", - " -> MSELoss\n", - " -> loss\n", - "```\n", - "\n", - "当调用`loss.backward()`时,该图会动态生成并自动微分,也即会自动计算图中参数(Parameter)的导数。" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "反向传播之前 conv1.bias的梯度\n", - "tensor([0., 0., 0., 0., 0., 0.])\n", - "反向传播之后 conv1.bias的梯度\n", - "tensor([-0.0368, 0.0240, 0.0169, 0.0118, -0.0122, -0.0259])\n" - ] - } - ], - "source": [ - "# 运行.backward,观察调用之前和调用之后的grad\n", - "net.zero_grad() # 把net中所有可学习参数的梯度清零\n", - "print('反向传播之前 conv1.bias的梯度')\n", - "print(net.conv1.bias.grad)\n", - "loss.backward()\n", - "print('反向传播之后 conv1.bias的梯度')\n", - "print(net.conv1.bias.grad)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3.3 优化器" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "在反向传播计算完所有参数的梯度后,还需要使用优化方法来更新网络的权重和参数,例如随机梯度下降法(SGD)的更新策略如下:\n", - "```\n", - "weight = weight - learning_rate * gradient\n", - "```\n", - "\n", - "手动实现如下:\n", - "\n", - "```python\n", - "learning_rate = 0.01\n", - "for f in net.parameters():\n", - " f.data.sub_(f.grad.data * learning_rate)# inplace 减法\n", - "```\n", - "\n", - "`torch.optim`中实现了深度学习中绝大多数的优化方法,例如RMSProp、Adam、SGD等,更便于使用,因此大多数时候并不需要手动写上述代码。" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [], - "source": [ - "import torch.optim as optim\n", - "#新建一个优化器,指定要调整的参数和学习率\n", - "optimizer = optim.SGD(net.parameters(), lr = 0.01)\n", - "\n", - "# 在训练过程中\n", - "# 先梯度清零(与net.zero_grad()效果一样)\n", - "optimizer.zero_grad() \n", - "\n", - "# 计算损失\n", - "output = net(input)\n", - "loss = criterion(output, target)\n", - "\n", - "#反向传播\n", - "loss.backward()\n", - "\n", - "#更新参数\n", - "optimizer.step()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "\n", - "### 3.4 数据加载与预处理\n", - "\n", - "在深度学习中数据加载及预处理是非常复杂繁琐的,但PyTorch提供了一些可极大简化和加快数据处理流程的工具。同时,对于常用的数据集,PyTorch也提供了封装好的接口供用户快速调用,这些数据集主要保存在torchvison中。\n", - "\n", - "`torchvision`实现了常用的图像数据加载功能,例如Imagenet、CIFAR10、MNIST等,以及常用的数据转换操作,这极大地方便了数据加载,并且代码具有可重用性。\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. 小试牛刀:CIFAR-10分类\n", - "\n", - "下面我们来尝试实现对CIFAR-10数据集的分类,步骤如下: \n", - "\n", - "1. 使用torchvision加载并预处理CIFAR-10数据集\n", - "2. 定义网络\n", - "3. 定义损失函数和优化器\n", - "4. 训练网络并更新网络参数\n", - "5. 测试网络\n", - "\n", - "### 4.1 CIFAR-10数据加载及预处理\n", - "\n", - "CIFAR-10[^3]是一个常用的彩色图片数据集,它有10个类别: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'。每张图片都是$3\\times32\\times32$,也即3-通道彩色图片,分辨率为$32\\times32$。\n", - "\n", - "[^3]: http://www.cs.toronto.edu/~kriz/cifar.html" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import torch as t\n", - "import torchvision as tv\n", - "import torchvision.transforms as transforms\n", - "from torchvision.transforms import ToPILImage\n", - "show = ToPILImage() # 可以把Tensor转成Image,方便可视化" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Files already downloaded and verified\n", - "Files already downloaded and verified\n" - ] - } - ], - "source": [ - "# 第一次运行程序torchvision会自动下载CIFAR-10数据集,\n", - "# 大约100M,需花费一定的时间,\n", - "# 如果已经下载有CIFAR-10,可通过root参数指定\n", - "\n", - "# 定义对数据的预处理\n", - "transform = transforms.Compose([\n", - " transforms.ToTensor(), # 转为Tensor\n", - " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化\n", - " ])\n", - "\n", - "# 训练集\n", - "trainset = tv.datasets.CIFAR10(\n", - " root='../data/', \n", - " train=True, \n", - " download=True,\n", - " transform=transform)\n", - "\n", - "trainloader = t.utils.data.DataLoader(\n", - " trainset, \n", - " batch_size=4,\n", - " shuffle=True, \n", - " num_workers=2)\n", - "\n", - "# 测试集\n", - "testset = tv.datasets.CIFAR10(\n", - " '../data/',\n", - " train=False, \n", - " download=True, \n", - " transform=transform)\n", - "\n", - "testloader = t.utils.data.DataLoader(\n", - " testset,\n", - " batch_size=4, \n", - " shuffle=False,\n", - " num_workers=2)\n", - "\n", - "classes = ('plane', 'car', 'bird', 'cat',\n", - " 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Dataset对象是一个数据集,可以按下标访问,返回形如(data, label)的数据。" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ship\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAALVElEQVR4nO1cW3MVxxGe2d1zk46EhEASlpCEBaEo43K5UqmUKz8jpCo/MQ/Jj0j5JQllLGIw2NxsK4iLERJH17PXPEz316OZxdLoeb4XWruzvbPDfNOX6Tn64cuRUkopVde1OomqEbmsaqcZhIKbFTVJVVV5jRsWRGdRlaRc8d2Gbmtu3/ADTdM4Ql4m0tXavYs+NI1mVepjX9pUckUXlXMX7RMVcWbEwQpAppkCEACttMgsJiw1uOK1EYHvJWhtvQWqUhY0s0FrJqbGYy5V00S6Bwjf5RpdSZKU/vaorRrpldau2oRfFGdWAOJgBSBLZMLy5OS/Ey1DCakRXqAZBDZJunEayRVrkguNEpe3iSwO3DkxWCCv9R0ed1J0hvsO9uFtYLTSNg3d5QhsjTMrAHGwApBZfHJnvj2zwYtaYTLz5GzQ5qQiS2w8U6tsHmm3WeL1wmK9e+VEJ+C7ijlkN5VvZQluJM5HqTbvF0Y6zqwAxMEKQBysAGRwWH0XO7GMqCxVWNr4AawFvivQ1O6CaLvdWpNLXXNEnYgn7boC0r0Ga6ulSvrAjgVPg6pkj58vQUOt8S2W68APIhiIHvx5EAcrAJk18wkWc+ygF3EsLnLuSey9qwNudC0+hJU5Ep/Ddf0tEvqeDFwHiztYHBpXgzg0csv3VERV43EzevDnQRysAEg+S8JgvldbbEgSmBKka+mWzHyhFWJyNIJOC4hs4VJLPotttJDPzZbZy0cj7V2zqFJXuZh7lmp7zqAPSMYpV4g4HXGwAiCBdIt32tjWENlk7d/FE87f2mOHnSxqGtCQA1rtPqi8KxBOmjD0wPOQvTy45fm2JA1ASY3eeHyMOB1xsAKQiY/n08q68BteX4vQss9DQm3lfxEb4lErGhVb62mSXLXVP5ek1h0v2e0pP/HtkuJ2I9w4swIQBysAWcPTrvIqAM6I1CcdiMIzuYBHmGR4MEFWl1mQ8pNlUzhv0QolCOzxCotV3fD/OvIwKLNgVbVmd9oLRWtRLouD1q6vHGdWAOJgBUBI0VKMcDYknt8olUaNn8axd114ejeunRKC1O5mkrWDeyI6dL7DN82+oYQ1TKz8rV8A4Wd7Ik5HHKwAZFq5+/26JegTaEmQIg/jjnjLhEey0dr8hJXBbgjebO3XukkebM3aqkBkLQkc3PXKJrxMqd3hunYXkxQ+s/tVER9HHKwAZLLm+5nEVrRUyzGJavdOW+7UripgS4dSPNGIJCZtKaZI47CG1OonzC7KpPx9B+isGrwOHRUaViBpTZ5qmqb8FRFnRhysAMTBCoCUSYK24H9b4liWNvHFUZXfkrjlhYP5n1mrQ8aBcCUmPOFuUbMclf68q4QOpFYVJ+JorINwgLAdlXiOQtW6THsZ7RhInwdxsAJgbbJaRbzm38r2tgE/Nhan2fWMMZNB0IP9D9C0vf3OCEXB2StW1ZuYcl47nBxSr/hQTpL15TO482VJroZfoCDuiFecUdvBAF/WnBqLZ3fOgzhYAcjatm1cwYb2JrNVe8R/s+1Dk4QLfp/98BCq7t69a4TxeGyEPCc+Fg1Zyi++/NIIn9++bQTQcHK2B1U4QqekNApW3k1tV6UbFdi5A1hPmGZvszXiDIiDFYAsscp86N/W3BPguay1xmRmTR5/Gy6xXbh0ERdXlz+hFzEdtt+/N0JeEw0zVvr4+wdGuH79Bt868Qb+CPSKbTrTFoF3guJcvlLZpcbMOkmXt5Q2RpyGOFgBaNndCd/fYQ1ylo6Jyf8X+TGZuV5X3njzxroRpqbIBf3mm3tG6A5njXBwdER9YtZfnL3g99M6g4cKRWTZvEIoT0r8PLhStXcsPs6sAMTBCoCYlMqLqqQw1vb6pJCBnTdVOQ+CAjjR8fbtKyN8d/9b6Dw+PjbC5i+/GCHNiKTXrpOw9XLLCF999SfuFPWqKqQeIvUOi9f8OR22ffiZCvldB8mQW7UOqPzDOHBqO86sAMTBCkBWeT+XIlV6lt2QX3GQ/U9qX1aF00YOjLEvOneZrJvqiDVMFQV3U3Nz1GyOXNa8yo2w9YpoOL+wyMq5JMi22rUwivopd9wtnFq5YeOJPSfv5EyTRGsYjjhYAcgQOllzklDVYiPQLFMwgkix8hlLsaL0f3BhetoIPzx5YoT5K8vQeXBwYISpGaLh/v6+EV5vEfuevPjJCH/7+z+M8Jc7fzVCryuZUuvnlOhKXoBE2hFg2cUVtew+fNESzWKtwzkQBysAcbACkB0XpXNJ9kUsM4/cccXubJmT/52mXW5BQ//zTz8b4e3bX42wf3hohPxEJRScD96w6Q2MsLh01QhXr103wmBIy193YpJ7YvWZ/Ymyoe6N+St6aYe/y1udJeQQVVhwk9oNSOLMCkAcrABk9+7/10jwtuEldKzcU6/DfnNN/vrkgPzvJCEaNglduXdvwwgbG/eNsLu3Z4SF1TXoXF4mN+Lp06dGmGNXfmVlxQjrN24aYW2Nkl9vft02wrgQHoJZ45w2ipBTyziQxg6TtfdLRCtKey1q4SZpcC9EfBxxsAKQvf+wa6TBgCxRxkmlzLKGmoPJNSbIzDTlgvsDqkJ49uJ/dGuGMr/r69eMsDMi13x6fhE6//Xv/xhhc3PTCCWnqO7c+bMRZmcptH786LER3rwmGua2OWQTdshmt9MhIwinPpX9Hg6k4dNbNMTeKtYlv4Y64nTEwQpABpNSHNAEnp2l3FOv30W7hUt0scPcHI12jbC3T/Gw4jNqv7tJlmtpiUi3u0c03DnMofOPf/i9Eb74/DNqtks6+/zqmRnyRY8OaJvnYH/EfWeiWdVRiIgrzohhdwe0bbyAv2yj4W9UL0WcjjhYAcgSnszb22Rl9njCPzvaQbseVwpcmiVepFLaQCPe53I9mNGq5NxQ2bJBsrJ8hVRxVT4MMRzjfEz28ZPFy0bY3KRUV29yILqYUKMRkTTPmYZcnIsMV8qVvzCCRdFCQ+tcbsxnhSMOVgCyhmfdxUs0z1EOW42lWLbhY9mDASVzUQePCp5KUZuDQ7KPBVfyjXMOPGsxYTnzGDSE3cmYKSknWLocga6vXnUeV0qV7HlWnDhqeM8JDNOpe1K8kjNDkjgqeenAmlDHFM05EAcrABkog1mHdAccQqWULjkvyns5OVfN9jPKzHSEO8je8OOY+aX1Yww1NjvlPdyM+ctv2d+jDmRMzP60dC/nOG5+boaUF2TT9yoUPXT4HbKBRVcSoXQxphdVXAQMWxlnVgDiYAUgO2YaznEyBDwBv5RSyyuU1ex1aTI/evS9EV5uvTHCYEhbCUh4dlLyG3WXnUxl5yS50LxyDWuGA6mcGtIDEsbwNot9UcQBYMo1VDOTE0Y4PqRDL3VO2VosF3ND3h9ZmIcq1Dq8eU0PVtXgRHcjzoI4WAGIgxWAbOEy0fWIyzQS9iFu3/4M7VaWKTO1NyLmT0xQNvnwmIz00xfPjfDkx2eknVUhRzbJJ+GU5a9P8PrS4aieM2MSig/6tHCguPKoOIYq/KbTaIeC//l5itKHvJIOp+gtV68sGGHpCn17t2M5NLwX++7dB/5k+sA4swIQBysAGfI+MMljrtPf2JDK4offkYBULJJWq2trRrh165YRUGb14AEduHn+nBi6s7MLnb0eu/68EwNh0KFb3Q7Fz91u12lTWbWNSUqdQeHFCgf8K4urRri6St7PBU6E9bFzbKnCNm2vR+m50ZAS7nFmBSAOVgAyJGum+QDN+JBouPVqE+0O93aNAIp1mBf//PprI3Q9WoE7S0tLRsjzH6ETaazhkExkxldqjl1hm0bcAcTkCJ6VUkfHtIZ8yiVKO2wWYaw7XVI+9SkRM0mQ/hYavt+mF/X7ZD3n5siUx5kVgDhYAfg/pQ4eZ65sAxcAAAAASUVORK5CYII=\n", - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(data, label) = trainset[100]\n", - "print(classes[label])\n", - "\n", - "# (data + 1) / 2是为了还原被归一化的数据\n", - "show((data + 1) / 2).resize((100, 100))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Dataloader是一个可迭代的对象,它将dataset返回的每一条数据拼接成一个batch,并提供多线程加速优化和数据打乱等操作。当程序对dataset的所有数据遍历完一遍之后,相应的对Dataloader也完成了一次迭代。" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " cat deer horse plane\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAABkCAIAAAAnqfEgAAA09UlEQVR4nO19SZAk53Xel0tl1l7V1dvs0wBmAAx2AiAEimCQIilrs6wIy3JYctjhi0+KUIRPvvrkcPhqRVg32WFH2AeFFZYoOyTKpCRCJkWCBEgAxDL7TE/3dE8vtW+5+fC/72VNVTWla9P/u/Trqsw///wzK/N9b/keYMWKFStWrFixYsWKFStWrFixYsWKFStWrFixYsWKlf+/xFn86Lf/2T+W75zMKNFkDGDc78i/45FR0mQq+6Ty13VdjitKPI2NksHnMV0eKpvdOQgLRikWQ6P4BdnF90UJAtmmVCoZpcCvCp4MW3AdAFkmw2ZpwoOlPL5sWQzkQOVi0SiVUpHnLisTp7II//zf/AfMyMVXNo1SXz1rlO7BUHaJZBcvkANVm3KgUi2QyfgyfhRnAApcyEd7B0bZuHDeKGkk3/WPxrJv7HK2Mlp/0DOK68n61Os1AFEki58kMScuc5tOp3Nn6vuebOLJJ54vB/ID+aRYCgCMB3KmDu+feCrDfvD2D/G4/M6//FdG+fJrbxmlfSyzvX94DOCv3v66+ff2w4dGefPNrxjlaedYxu+Jcu7zv26Ug0jO/c7tD41yeHDDKN3jXQD1alXWpCpX6v1P7hilc/xAvsoio3zx818wyo3rd41y4dIlo2zfet8ov/qLbwHYurhl/r11Z1+mtLpqlJsc/79+7S+M8upXvmyU15590ijR9n2j/Nvf//eYkd/+F79jFL3VPU8uhxvImTr8KsuyOSXgnex4sk2SJLODeK4oSSw/h6AgP4d4MpEDpforczm+bDOZyP1jfkxONuHEZQM3rMmBQln2zJNpT/UmTKc8Ndk5iuSWTibR7OnEsVyX3/29f4fHxYUVK1asnBLxl3wWy9ssLPKJjgEAL+ubfx1XHpkFvpML+oB39QEvg43G8liNJzRweJwgDDBjb4Uh7Z1yyA3EXnBdtQJ8blyYG81xOL6TAfByO27+lRUEYkYVaa8FtCmcTGbrOnw1LTNCAaRTOdz+jthEDt/5NFmQuLJN4vBt78uhx7EYKYVKCKBcF4OxnkRcBNkygyhRJGfUO5R9hzR11WZUu2k4HIGvWQApzUzP05dnxtk6/ETPja9NDuupFea4AMo0b8OCLGC708cJctyVrz65c8soMV+5e8dHAFpnG/J5IMe9+f4PjHKxJed+jldzjdbZ+lf/vlE2X31OTqSzY5S/+J9/AKDIpdgdyXJ1Y4EIo1he7Jc3zxjllRc+I1O6LVbed97+c1kMyBX5zrsfA3ju+c/KuTuPjPLBt75plMBvGWV15aJR9g9l38++9kWjTNY/kXX5fcxKgSupii6+S5zh+PP3fERLJDeTaUl5jgcgpYnk8YeZ8cZOE7GS9GrGEUdz5VYo0GD3qGSpAwCp/ELVcnf5401dudNcXw5d8NXUAkeT2QZ8gMTeFECaqkF3oiFlLSwrVqycGrEPLCtWrJwaWQIJ6011bItSqzoAJiX67cZi29H6QyEoGyXjE3A8IRih0z1OxnMHCn0fQKFApBaKoUhXMqpVOt2J5tT0dR0ZP8l96iKu5wFwGTFIErEzpxOZ9ngkSMqt1eXQBKGuR2OYs/LS5c90n+7GCaFHRsjmevKVz+jBeBRxMrKxH8qy+BUPgEuvdq0l4DGhP3hIS9ohCi5VK0bpHQsqqVYqevbmT7/fBxDRyC8QVqiikDCHHlSCgLibG8eED+PhCECrIeumoY8pgf+iXL99xyiNspzaG6+/bpRPv3ULwL37AhVrZTmLF56/YpTi4Z5RDnZ3jbLaEVd3CXLE5oogyj//+h8ZZWf7AYDLq/L5rR+/Z5RWTQDmL/7SV43y5OUto9y9e90od7YFsg0GgvTf+rmfN8pzz70MoN2RmycZDoxybiKfxEdyOYpDXjLePA8PZNqF6QDLRH0deofnCiGg3oger1QKdWWAG4sYL0rG+E+qPhqNhjEM4/De0/s2pmNef3cakJlOYgCOw9splN9qyhlkKc89476hukqIRhnIcghgUSgASDgnxcWLYi0sK1asnBqxDywrVqycGlkCCTvdtlEGPTERy2EAoMCYWsEXIKAJQUFBsE9EY9IlUitWxCB0mdyhIa1yuQSgWhM4WQg0nMc8KaZHKb5ziPw0bqUBRM9T1FMAgFRtVDmLlHPTUJpLWDEYCFyNmPRUDJmZlaOtx+TlF58wykFHhr13QxBBjZh2TMw8ZKCqALGB11uSt1IsVwAEFYYCaQyPOKXpVAaJafa7BVmfUlmWztXV4O4m4BLHmoelIVReIIIHDc2oOD9ByYCZ6HBCyKlJT4tS1qw6Xt82b7B+rwPgledeMP+eXZEo29OXn5adCbI0D2sqgyG+I+lXLhfh7gcSW5xM2gDe/sF75t8Wces/+YVfNcoLr71mlGuvvmKUvY6M75bkjP7HH/yhTOaZa0b5/OfeAuARAP7NLUn72rwuKPJi7YJR1mJZsG4iMc3f+0+/a5RnN+tYKry4izlWWaIXiBdR8+MYPNWL6Co2dBzMYDoN52kcX38vE8L5IJCNM41up3Rc8CfpSf4gcx75w9cDezz3mEeMuGKe5pEl+tvk3o4DIE71Ll0emoe1sKxYsXKKxD6wrFixcmpkCSTUGhePj7MszgBMIwkVqcHm+VpownBSLNtMUwFBil+qRWYbEm2Vy2UAtTqT+mm+akGAT6tVUxlT2pmaCanpjppKWqvXMRNJUYs6Yb6ixs6iCZNa+ZXm14UEMmFYwjJ57bWrRrmz3ZUtCfeePHfOKPcZG7qx25ZTpjF/9axsY1Lw2jyv9khAhKafJoTDKVGcmvdTAtgC1zbT+ItjJi9nMWH5xXSqRRXgtEVKCsA19sSvKiWBXWaRc5RNyOA4J4Z11taaRtnbk8TOo0MBehurGwDOr0tpy4U1Ua5eFUjoeM8aZTiVZfnff/onRgnee9soFy8JNj/bkEW4cPYqgKeeF4B2sSwobIPKB++8Z5TWplTtXHh2yyi/9Gu/bJRvf//HRrlzS4p1PnttD0CLAeWtp+S4hSPJNT18IOfVLAu2vX5fynp2+nIn7N1hCPxx0QxPF/orYBhXf27O/HXRH4i6YnLJzC5yv2nQOckYIs9kJhEjiQoJC76WyLCujuE842+JCeXiTBMGOKlMA5e8b1mRA24TEkhquHASjwGkC3B4UayFZcWKlVMjSyysp69KFsx0SL/vcALkZo4aL+omd7W0kqk0SWNFNuZLQA0rL3cEZgCCAr3yno7G8TWDA+qApBeZVpLPhCn16ZpH/JiedbXBtAp6PJHzcjL17osFUVY39kJ20pyUCzLa5qoct/WqrJvHhKyMGS4x1ydgOkyDFa2mlGHKtQ19pm5FNDNpu3hMfklY5JSFjDPQwp3wkrlIAUz53o743vNSrRti+hg96C6vS5V5XuD7s8BainIpBICMeTq+XtMTLazzGy1OUk7t1vU7MoVRAqDsyOdFKh6jOgnN2wlfxccDyWOaHom91uu1jaLBnEmcAXj9rc+bf+OurMlfv3fTKJvnxfjqj2Tdxn1ZqJXqulG+8BUp/dn96AM50MN7AGqbcjqdqWCIqCEJX/sPxIzqemLFNBixiVzZ6+jhPpZJHjtayMNSh7pDsJJnI/KUk3T+JpfymnxLOVCs151GjAZdtHSmzPvK02xHpTlABoCAJI8IpTS6tKbNDzTHivZaHhAghuNt4yYuZlHRQiBIxVpYVqxYOTViH1hWrFg5NbIEEl48R4Kno7ZRBt4AAOjXU8tNszOU8ibjeG6gSf2CfRIsFNNkGYBuR/ypmtNRqQouU5+uFotosYvrBJyDAkwxI0fjCYBBn9wSPK+AhSYBvcsFzUnJa9mVBouwa9GdaY7Ls2hWiETo9e/t0xPJ9XninACNpkcfN1FVr9OfPS6JMBCPGAcgmovpL89YKKPrr1k8RZ7j+uoKgDt7h+bfPmuDWEKPDVa0ZDnTgxJFcJGhaJHu0sDHDLWZOk0znOglnXZlDm984UtG0WS07bu7AA6PJAfq/v1t2aAo3vGXPvumTJuXu5pwShW5SWLClg8/IRFCWAZQKEswZ0jYcv+RHKgliA3PvS4reQUyWi0UV0ahQIINrn8hiwFUSrI4RyQF+6Pvfs8o62QBubb1lFE+/8JLRvnu9z82SufBMZaJGg6aH6cgUf3xM84N3pz5AKw/Y2TMoEV1v+TYkIM5iTpGZJsolhs4JvFJyF/xJGJBUpICcDOlOZGJRwpF9TGhhHTMuko53+FYf9d0KTguAF/DSpmFhFasWDn9Yh9YVqxYOTWyBBJutsQqrjOLZ1gdAQiVpY9W64Qg5ZjgUaMGoNmvgYAiI0qaQjUZTQAMmBg0HIjF2Okz4csXJSxLtKhcEaVGIgeHkE0Vkz526axQDIdFmYCvBQ2F+eBjzuhAbDshPoqTEyAhUSSYgtTtMhoyFmXcE0N6tSX462yzKQciJOz2RwAO9yWXxyNcbTE6eNBlNGogkZppVyJlIZc0cpQxQuz8V178WQDXrslR/uzP/8ooTcLtr35JCIvffV9yhXZ5EZXRQa94HMv43V4XyyK/IxJgLEp1Iti8vyO1LI/274lyuA9gvSHpVy1Glm/cFwrju4/+l1HWGGlqEk0cfyQAcJ+h2CtPXjbKOC0AuPup5E997itfMsozz79olHu3BHse9+WUD1mak2Yy/kZVbpLzrwmsazRCAB3SkHz5F/+BUb7+7W8bpcKg9sULQuC32VozyirvWxKgoD2XD6d8ivzAXQjwAcpznW9l/ii7tSZkmWRGdb8U+FtWMmWP/paE0buUSFAT/RLNBeMVTxOjsMCLBA8e102nnSmDZjYf+5shj5RNTA7gdKKELjYPy4oVK6df7APLihUrp0aWQMLGStMoTSomUKI1LlofEzE5U/vcDAiCtM1GwoyyiEUnEakLsmkEIIsFpAxisai7EzFwh33BPqORlL8orNvckCDOuTVJydtsSWhppekBWFlhNUlRwamS/ykdIPgVo5BEi2pa/62QUDnX2QcHGZNaPe66UiVToJZK5PTVCYAC3xxPnBWWcZcTOFORfQ/asm43yCE3JnauNORkJ7TM723fA/Brv/Qr5t/L6wxTrgjsOmoLCDo8FKa6ep1xQ8bF+gMCPYIFzw8xA6Vz3osT0msBDNvCpH58IJUrF87IZFbX1wEM23Jj/Pjjj4xy8drLRgm5Stu7km/5Aq/y+gtStTM+FG6/ClM6G/VNAPfvyS4txhM3tqSY5skLW0aJSSTwta8J+V/GOpJ2h0x7vPf+8v0fALjy7PPm31/49d80ys+Qj/DqmQ2ZCQQA3iSjQ28oi1Cpyt2Cx0nwc6597UDlqAdDy2Lm+fg11qxMeCVl1EvS2S2VfV+pH/w8XEhXA+vqYtJYxmQiVEY9zykASBhhVEioPyUFgooN9YhT1oeVeEs3ecsdHR4AmE5l2FTZBRfEWlhWrFg5NWIfWFasWDk1sgQSBsRHga9+/hTAhFTuWn0U0Fht1gSghUwTHXPjEWHFhJGFlEXbncEIwEFfzL9+KltGzHabJGTL5r6ahloo+3OKdspaWasACAKdpMu5Efcpk3Ruxyoduxask8RaK6weF4e4KSyIkR/4smWzLhVkBbLgN0ge7zCaNuqLsr7SANBsyuQ3G7KSJVrUXeZ8VmmW7+9ISHFCtgYFmCBQ/fCj6wCe2RKQ9dqLAmQ0IfYP//hrMj7DtAGp5fpMiewTxWsqabNRx0x3Jq1cUwKMRVk7tyXbMEHx6JEwtT/zwmcArL8mKHhIlNLcbBrl6lXBfUe3BVupYwEs+mutCuLbPCMVgoNHfQCIBNPppdRwVUFf05z//o7EDT/4kbAA3r4tHU/diYyzUgkxUzR35ZrEHL/0pS8Z5RxzcUMSW97/xjeMsncol2xtQ6Z94yHTXAHMdNZKcyYMzlEDfwTmmqY7Q/KnsTm9tzOzhflvRAZ6OJojrSx9DAWy2DN25iczJZ1k4DqYCTVqxF+5GWLmCSRsWJdxtJjg8Ec/lG67AaPMzVoNQIMI0bMEflasWPkpkCUWVveR1FIUlbPYzTDrDsyrrlnfz66f2vx9HMrzNSQtwTGrUvYGsteN4zGAg2N59mfkBigVWQwB2bdFTiUlzyoVlZ2ZBlQow5arRQAVdoEN6EHkqycnCdLnOK2l3EGoTlD3BKanUN3zXJZqhb587YBCn3VA0uRqIAZUzFfF8xdXATw6Eg9xoFSztOw89nbd3tnj3MQmrbXElJuQv0EpKiJkAO7tyS4/87q4sTsHbaPceyi+9pjXZcAUJ59xhipjBaOxuIh39/cAnKMLX2uDFnmWVe48POQpy+Sikfj7G80VAFvPihn11pelmU2Bfc+P9o+Mcv26dNaZHstoEdOIGrRZBgOxkh492AOQ0L1drMuaH+2Ie/5od4erIZ/84Pv/1yiffCo2qVaYFJj9FDoJgJ0dyRF793t/Y5Q3f+ErcoJchAJtijL79PTpud9Y28Qy0Uw3ZTLIHetsvaNtcd0FA0T/z+iGN/VzZabdqVc+oZmpju2Ye2esRtKKnIIrN2FF+c3jDEDGQfo03LS5vNKaJ3ThJ+yjo7/QJxn9UBe9AXPa2iezpTlWrFj5KRD7wLJixcqpkSWQ8OE9Jraw4U29XsEMd4JaoB7tTFf9mjn9mJL8ySCHXUlF2TkQn+4kdgFkNLmVXnk0lEFYz4BaILZinYqWzrgZU0Los0/HAYBxJIOMdNp0EIbMVamw14vOX7vmKEeadricE/XcT2kMZxBlQst2TKu7Q9q51qrglwsMUxS9NoBGRY4yUpIJ+kRbdZntE5fFOd0niPtgW/DRtM9UmqkACdNsdUKs21yT4/aYujVgrxS3JPhrQndph/wZDqvntao+jscAuj0SbHgnOkdV3v1E/OUXL8ocilPBR+P+IQA/kMt8/oKcYIEgaLIr94w/0YagbL1z2DZKEpP3cVUmY9agS6dvfyizPbwpxTo3PpRqpEe7UiT0YPcux+d9lTunZTKDKAZQZMnU994RkobLL0lA41VGNvbv3OTZs6sokf5owKDB4zLT1xZU9KfEHxeXJVuAhOrYVn+N2SLn2mTExuX8yxXxcPdJ0nA8ZksnJmAWMrkbpwPJoZtMepipxNLSH522/pzdPIyQdwQyf6ukyUTOYx4BmHLxNb9sUayFZcWKlVMj9oFlxYqVUyNL8M5wKHZgmpCeAT6APgtl1GwOGUasVNkRh00r4bIhKLP7t9lT5O49ApnYx0w0pErGtXpJAEhYEPNyheOvr0v5Ra0lympd7NiiJwca9DuYyb5xaSS7uR0tn0TEbhrjUv4JNdE15WROYkKqjKFMJVHrs+i8Tf6G47aE5FbKZzltOaNPbj4AkKRy3GaDsa0SQ7S0qMuMh5YZsskmpOXjqa0QxSfRBMDgSC7lkPHKPRbipAPBZR6nXaoKa0KayCC72+IcWGnIpVlZqQPod2XYFjOPPP/EN19Kn0BQk5jmxaKUB7368lUAXk4OIUihvy1RvP2PpHVNlTU0Zd4bet88uCf46+HtT2WbcgXA5nPS1ijal7M4vCmjjXYl5lgjkirk/HOcNjPyXIa9xnEEoEs+jyIjvy5z36rsLXS71+EocoGe3LpklEYoi4wf4zFR3OTMZzYpJMwWOtvqXZryk1SDjCkw08HX4z3pE4WlU1nSfibeiaQoJBNeWcD7MdPTpt0dzsoDWfMB+FzAIG94Qy5PbX7MWyPlJ0oCkUQKhF3M9HBwbWmOFStWfgrEPrCsWLFyamQJJFxdbxqlxMR5x/EBDIdiXo6HgobG2nNpLAZnpaaEYWVurNx4rDmg+Voo+gB6x5IcyARJBJkAkEpLBvGVfqws5muT4bZGWY5Y8hjgy0IAYcCET1r7mimX94MipcSI+G48Fmim/b6Uo25Oxiw9B+uTQmZR+oyd1aoy7SQRADhiPbpWSoxHCYAS8UWZLVFdBpjiiIEYZuudYRDncy+1eEayF0OL2Hv4CMBKUzbY2xM83u4IWllbkaRQXhYMGMYtsW9VifzlKRNSjx4dANjYEFyjeDnDyTZ8UU7tY1IX/PyvCb3BM09fBrDNrNRGQ67pp9els1bnjqRxrvDmCIkmzgSykmNfoM0x+24V2n0Aq4cyyZU+Q43rcl+V+3KCh8RuVfLxt+klGKrfgBjQdKXLpszaXZdbUdNre0cyWkz81Scv4OsvkwWwJIwOf/zNP55ZJBQYdFYSdk3K1XCbo9wM2gEMWqyz0McUKQCf9wxIrK5ga8LSpQ6rnZIVuVsmKUGoNh9jr4EbN+4B+PimZM8+T9LEl648weMwYJ0uJKbm1WyEfgzBGzBYrzNRmY+aRbEWlhUrVk6NLGtVTxdusajPeAeAS3unHDIzpU/vMsmV2m15WyZa6co2lnsP5f0zINVvIZwAQMYu6rR3jpkMNY3kQTukk2/EYT2+1jbpfW+xbNjHFDPe04h+9DEf2wMahuOpvkW1YnPB05ktf6Z7bDZZJhfYZExCZLLrliraKV4+GfJt36CX+vlnrwHotcUC8hjQ8D368nmJGnWx9Wor8tXZqUwy9cV26LFItRIWATSaq/xc3qJHtLBSfRXz3a6l1HFXNl5pyiIXWREVTX3M0IepBEGAE+RzbwkXc78v5/jaF6UEp9hcBbAK8cEXQ5nShz+WCuSDB2KU1WM5rybtWY0VXGjKsly5JMbL6HAA4AKr4je0UQz9watk3O522kZpVeUED5iMdsyrGUdaWRIAWNuUJf3q3/uyUc6fER7k40dSX6U5ekX1x9PMLCz5wQHIfe1qAemt6LustqHTWkv01dQKNKGJ47lZBlbSYLbih9dOMUSFRWvdqfwutCYPzNV6+1ti6t69swPgeCA/80PS0o0yiSo4tIG8sCljVGTFFPRo/yo1JzuDIwD7exIeGVoLy4oVKz8FYh9YVqxYOTWyLA+LVE36nfH/RYRsKZ1/PtmHayTeTRJ5Ao7HtMO72leDdiyrakajPoACwVFIh3qpqA5vgVRFuqKLbIFTKimbAksB+PA1BmfO0KTVQvTweXQVV7WhjnJaaWcgxZQn1I0XmRDEyvac+6FAtDgZCpCZkHOqEgqAvXRRmvrEgzaAY3p2o5FsqcsC1lJoN1kFAjFd+IMRIxsjjYf0AJw7K2lfIZ3KUawxEIYgCoovmHlERDMecfyuzGpzvQWgypImbWIUs7frovzWP/0tDivw5PKWnLvpxXKmIJf71qfSbfSDG8IVdbQtFTMXuRq1TWE7SOl9L5NOY60mN2FcrAGo1mXYR7vCbPXJR5KoNSTFmMPLvcbOPTu8ZC4LxTKWXpkb7CyJmL/yxZ81isKx7i69/hz2+avXjNLYENh4i3Vvc6K9ncplpra5giJz9iulqUoUPyoReTz3leFWiXir5B1VSUhy3GOOHityEJA3hX6JoChIOSzLWZ895wC4XJGUukuXBAm2zoj3vV6Se0OxZ0TqBeVGPz4SEpFety0nkk0xQ0euNFuLYi0sK1asnBqxDywrVqycGlkCCaeMLEwZp5hGJu5Gur5qg4pYjIk++BjbirnxOqOE/cJ1o7Sj20YZPRoCiMasHqDpWGQNxyYN6SAv1pHxNRDj5HSxGukDZiChdnZR/jNtFAqCIKa8zNQ+6Pje8md6ljebFCVgxUycchhGmlihgc2LckYrqxKA++DuJwAc2s8xSSYU4RbyfBzwE05b6yEYCa2xjmRSLgKoMkx5/rxgww8/YR1GprhSJjemHc7OpFhZkdk6KYno+gMA00gOd+GCgLufUJrzMlOQVBSemHKUkP8ePJJY832SC2rRfodT6uXRWy7CiJlTd6R8pLXaALC+Lud+f1foie9QmTAoHDJPLSSbhcfqkFIsaDfq0wHi+QCeufK0+XeFfL5dcgpOWD80Q6cnylki2fF0OdjJi2wSjQmyIQ3ZLHymBCqDiLLl5UuaQ0IHM6zEE6L7eCJh4hFJLHoD+SR02LSGkLDWkBtg66nnjLK5UgbQWJHlqnEBU57X8EhKvro9iZl2hqLE7PLravkOc7WMOyK/1W0jVStWrPwUiH1gWbFi5dTIEkhYaYiZp3mPpnC835MgQo+FLKknNl6ZXWEKRY2yMWanXWccBpIcEkVXygDqbH8CQsIeCyaG7LhTqgj2PJ9IZ5SnrjA/TUNmkQbIRpjpdqMxNeUwU0M6TeYTR7W5iEOSvDhdbp0mLM1xq5oTyCjeWJQaeeWnicIKmczDPSluMB1KHH6eKjbUYn0m1irPtaN5gwqQtXyHgaRi4AGokGu/wRTQ7QeCm7S0KCI2TPz5dESfk9GoTbFUAqB0jT0y+fmFE5n8ZhgQFbyzVsMMS/Q9ILlgn6nFLinrHvEE7/PemxKDTAl2AtKwnxuPAFRYYHTILMQ9Vswocz+YRhtqDQoD0zXSjYwnMmytXAPwmZdfNf+WQwHX/Uxqy3bJGZ8yAqvhVPVLrDKePicjlrsNSJWn6+b6RIKFcO6rvL8pu0npNsVyDUC5JouwvnnOKDFTcDUMff5YLuLNmwKZXaLR1TXZ68ozguvrwRTAwf4+j0soSq/HjRtCjuizWssrsH+VEmDwvsp46Q0dvkYJs5wrdF6shWXFipVTI/aBZcWKlVMjSyBha0MiGgGLCk0VXrE2nyLps5pfa80URGh4osoenGsrEvtrkMPg3nEXQJDI59WKZPoVmER3/UOpYAqKrBPUNmJklY6mDKuR8KBWbwDItCElZ5Jno/Er7QPqERFUtS0SIaEyEc5JwvaWU6KVKGbCbSZnVCTl3gpP+fBQAmG9EakRyg4AXztWMripHGlFXW1meCYsLnOYSlrOiD0JPQyTX5GZvT96/0dUPpRzpzVeIe1crMHBFiOJiYw2YcrfpQvnMVMC1meAyTmhGRoe61u1PAdXIfeEfIQZ+dR9ZiEeD9pG2SbSr69K5eDxkSCy8ZHEFgemQnBbQPeAV7nHINR4QmKQVPBRyHju+aclCHh4Q+jetWOu63gAfN4PDotMA8IxveV6LLUrsl5vRNxaZnb0nGgIPmAisubZehnD6MwTBmtgx5pByls744VIHR/AyrrQ5DdfFEz37o/lN9UhOYdDbgY3oxuEwxUIkDukbPw/3/oagDNrgjQvXdkySo3Ek00y62uPu4x8kBPS9UWMoyqT3zQ2bhxmTRdOvJ2shWXFipVTI0ssLJ/v7YSpRqaopappIMxj8vMHIX2W9PkFZPhVuqXzZ8VwOzqSF8XBcQqgTkaBc2eY3lWUfQ535SXplyQOUG+JzzII+P7RPClNLIoed5PzP+2fqvOPeYLq7VObAvxqkZmAW8poQ+YBuXQ8O+zfE9CKKfNNtXOk1Q8yzsVKFYBHXoqiGGfodpmrQusppYUy1nwWR17gzUbTKJ22mBuGQ0K7zH7ne9Lipc+Xf4vkR/lFpIWlRnG/Kzk1SowlPv085Y1lK9GJiTN5t1ptV6udXRwHM+TAA1rN2sF8dUXskfvbsj7PvC5cWr/8G79plDY5p771Z39qlFsffQDgJkejqQqfpxzFskrKdfUiG95ceU6U778vdULK1mAS1pT6LclLdkS2traMcuO2UDCnGmeg8RWcQNeQ0WmtjWM0OVDtWbAlsNaN+SzbKigxVvaY4pE3/PZtNoiNZLn2yRP9YJt9betijqW1FgeRr/a2JYmy23kE4OJ5CZR98y++ZZRPb0rZ0xfefNMoRdqMMR8Bak9hoZOxk4wxE7ZSo3VRrIVlxYqVUyP2gWXFipVTI0sMVM9TjEAn5WgCYMrKA/Vea7HLiMzCpYrY8DWCFC2mqZEwoMgsj+7+DoCDHSmmP96TXKFzZwQklpgRVmeNSIMl+IhkMt0DMXojlm8Y5OEzzaRE7rQiqy6mdO4qZV2SKsCUSarBDxAkPi5JTN47JjRVGxUOIrUUiYI4El1o2VM2JmXtxAXgaZEQL0hQ0katrJhhMYq6+VdYGtWoNY3S6YkNX2utAUg47IiTvHhJMmsmpIXQNKkmmZd9+n11ETZWZf3rlSqATlcgVZX5cScB57+LKNTtD7v8hPk4BBFrLG351X/0D43y87/8K0bRhLKXX37FKDc++QSzLUs5Pr3Y+O//+b8Y5cMPhYv5lVdfM8rVa8KvsPonAjAPD/Y5TISZihZVUl2ldQFKD/aFjWAGDosSx8tvp8XPtW1PzIuo4Qv9ATo5piISZIjJFEtd/5HwXnz3h+8Z5a03XjRK90gSvrRipn0keVjOnixLWuEV6chX9+7fw0xIbUhyizLb8e4+pFNC64c0NMSomrY+Ukq/glfHDHDWB8uiWAvLihUrp0bsA8uKFSunRpZY8iOChYjG9mgwwkzqk/Ii+AuJLRMqI9qKdW3BUhe0ePZJ6dd4aecQQLEkSHBtVSKATbI1BDtijddr3Jdoca1JnjCSPYxGZPIreAAyzS5h+xm1ltXgnEw1SqgpQnJGWqwznSy34UfsBpQXEaQCOUOaxwViq/5EJqnceJoL5qMEoFRhIlUkUTxXg1u8Rg4JwmtVpgIx5DRhbKnclDVcPX8ZwHc/EESg+V+tVtMoY22xqaEZzlZLf0psi6vT3tnZBeCQgHxtQ0BQTpf4E2Q+SCiirV/G2tqzTyTCRjjPPifpUS8xn6jKQJIi5Zeek6+effp5zOB9h3ApYbLVN77+TaPcuCWRsvNnBSmvMgx99oyc2kcfCz7yvQxApyOQp0dc7CXKPSAHWueyzESfZRtd9jnRBYwXeoimjEdnZAGJCEJTHjHWPD5l8osdAH0WMF08L11tblwXTsTDQ0F5CZFmSAK/2lh+s1caTxnl+IEsnan9ijkBpVpZWyMzJXlB1AWkzVZd8o142sxV6UbSDLMI92SxFpYVK1ZOjdgHlhUrVk6NLIGEHZawZzRoo8kUQMwsu5wXQbEhYYVyJIz67NDVF4xz1BGrcpv9vuLxBMCjtuC+aCgINF4V3DcmFfpGS2zUFiNxTaYURgPS5hG7mRiKBo+0KaOam1pi7mtDR9ZbpAwXDtkKbMDGWXOiTVgdFjBpeYrLlDmXxrYW3Kh5XGLIcnN9HUDsCmTrH7MZGgOvASRps8E6D4+xlTEJ2h02Xitkss3b338PwDvvCyTM6foIGVoM/IWsdtIFGjIcXCYkVDR98OgRgFJVPi+0BbtVSPK9KEvq7vPYWQZygQB5y9sJJ6mVRltbV41yYUPoOlzG/pgbq61CUQoKAAKukuZksogLHjNjq+wm22LtlM+uoitapMVZu14AYEpc2dkX3gufrUO7rL/RWHOFhTjxWIZtd9hB93HRYhRtmKYAKl8uLSnLZBuFhFONr2k4chwDOLorBUYTpRJkSdPmatMoh20paXqwL8rHJJ4/c5G/Tf7233j5eQAhS6b2HgkuBoPyRToWfA1gav9XXjJHyRd5juZJk+Sc7id0UrAWlhUrVk6RLM3DIj/RQGyTbqcLoMvXqVYnqPu0yNIT/UTrJ3Obhf4/ZVNdqxUB7IzknbPP7JXjjmSInFkXD2id7+/QE7tj/768OpQmOPdnuz6ASCegLx0+8kO1sGhYaWdKbc9ZY4qT75ewTHxlj/LnncmZmqJa7pvKS7hZkjM5uymnVi07AHZph47YrEVZooZkcRrQc18iD3WfVt4KU6iOyDL8V995B8DukQyrPFytGluudnlxM1l/bRSuSu5bpfVaa9TwWJoM3cDREkNKziMn/9VaFu1INAUQs+TV4QaafhWSuezqk+L9BQ3nw35bRmNNeN5D1HVnJ6l+7knu2JYlrdVYbcYq6AKXv0Ur0uNL3bBmu3QMj5nyFrIWvUvryWGPmUZNzLQJk8J6h4+wTPRHl3umlfWM47MwKk/IAk2VYu6t5ldxCiDJ2EmXPGLrLfmkHsq+jYZMcr8jk/zwUKzI47/8vlFefEYM29VGFUDKwu8tpvVlnLY+UALaXA5o6ualOWo0yl/jfU9zE9JSJFuxYuX0i31gWbFi5dTIEkioTLUHB4IBb16/iRnvrKIhNek6tHi1EalyAAyHYl5qr0e1YyvlCoC1NTIE0XNfrYqyUWGy0lTGf3R/wG3FqiwRNVSqgrbKUh5EWKGkXdwyoK9dG4X26GKPSDlUoc+10SximWifGIWEE0Kbiie7aCJbysILNyaJkiuwq9PbA7B/LG7OAfPXnKlAEoeYutsXDMIoAiL6leNMwM4hYXtvNAYwJJSbEt17ZL2aMnIyg+ZkDk9ckUS5IcGUo61cXB/AiOGR7W1xPNcbNZwg47Yg/THpp5UJw/TR7LYlznBM2miHkK1MX3gpkAnc+OgdGY0FJdrRNucUdlzMZDYleqZ6FlNByg2WbSWE2yMeeqXG24bNkEajAYBeX27FB7usaNlnjYsr98zWNQFKBV/u7bsPhcygTz7iOUkXilE0KSkn5FJqNq3RySnGFC3K/27qADhHF3vKlL1Of/z4Hqhq32LGWPq8vg+74nS/PGKRXHMdgM/DuEqRTEY5ZSjxmTaYOepZZ47Y3wb9LCS0YsXKT4PYB5YVK1ZOjSyBhNFUSw3kk+PjNoCAbGGanlNmUxYNBUYsKHFon9eaYpSC3RPHTK7pDaYAWi0BRyXmGdVrsm8rZJ19qlPS9jDgHAQJhgShyWQMwGfYokDb1+W+MZFOxjNFXrXDE2FCWfEExrUkD3URt0ZkUygxdsaErNSfhy39RJDyoNcDMCJw9hh/ySBmuWaNOZybl+fasBCKCWsZj2h6uBbZSKZODmuPAHDKuKT2wKmtSLJbnxlzfRLjalzYxLA0qKr9XzV7aFGuv/vXRkmI0Vwm5kzGQwCTMct6xuQLZFyswlFr5HQcdgWE6rLHE+bQcVYGekyYyKZdOZWTo8rksIwJWZ5OiTHZBo/dZLhwHE8AdFg25LHl7bvvvCuzbcit3jy7ZZSDA8lsuvERmamnyyFhlvMu8COFhGRrcB0t0qI7gmeUsfxIT9ZxUzxGC671N3LKEUPMHj0YdbJH9pXogkBVHQiOHwLwGK/UnsT5/JWaXHOpPA2zEsny55bp/CVKyF/oQn2SirWwrFixcmrEPrCsWLFyamRZI1WS5Gkk7pXPvIgZ3oJUe3Qo7UGslQGa0iZ/tSC7wCBdic1Ws8IYQFhg8C5gx1OPliFHc8lrrqGHcok5nzQ8p0RVJssxb2vKKBKI8rRqR0kaHKLdSs5Ip71elkPCkK19qkwOTBMx3ccMF3oMQq1tbHC2emoCSyvlMmZab2ptReqzoEFT8rjIgXb91MVmbGuDDW+ubm0CeLB/aP6tV5syPKk4XJfx0IbsEoTkHSTBYeqx/IWXpl6vAoi4bpW67FIqLV8lAMODB5wlV1vbZGYJgIA34dkVWfzVCkt/tJssM28TBvg8oonRQCnrGJ/yPACOEkhwECX1X+Et4RG2qIchY4ZtsymTWV+XTMuHhx0AD/eF5ODlt94wCsuicMAI4MGekEDssLVqiaHGp68xA/Ybb2NGlP1dZ+uoK8NR3KTrNs9qsEiEkbkOgIxwMqf+YDuoIe80dTVcZLsslwdSnKd0KVkSAcjyDrvzWdk6t9SZd8Vozq3iVs2szpDh5L5Ks2ItLCtWrJwasQ8sK1asnBpZYskrqnMJCTfPbQJot7UwTdBQoy5xpZ1dyZ3rkaRBbdAG8zkd8hNoALHZdAD4OafdfNGZlu9rhDEhmhgw/22qQTRoDMIBEBJSpUQTXkHbZNHi5ZELxD4BdC+GlpYtEYA6EwtLFe2syVxZwi7NxDt3ecsovbZgtM6xzL/klQAEFZl8maukJBNqJXtcU8IL1IjmylWpTHSZqfjsM78BYO9IcjKnrN8cMvA3JB1jSHykSLlA3D3oS6rnmFHCMAwBHLPozPVlSrW6DPIH/+3reFzikRwx0ppBhuRMBNbl7bRSlUEukU4v1jCrImaGoRXOK0O/r4Wb6RRAwg0KmspIIOMxAzMM5AL5gRw64d2ysiokfBvrAufbgymAARewUSPx5LUrcoIsvVSqxatba5y2fFUqaNL1Y6LxMjUhNHlSs0O1u5cy+TkLRXnOYrTx8f983lBBoDFH+WqzJqe8slqbO2KFHhjfhCM5pZh5yDPJ5HpK88fWuWmZpPbZmyd1xDzmVbEWlhUrVk6NLDEf1ljPPWHByt7DfQADZu7oO63Dogpt7F5mLlVeWEBFa2jU+26eqpowpalPE3aUSflujPle1eSaESdTpAvf44FMmb7yE+hxS2W+Tlma4+hbekR7iqU5+Ymc4AgMQy4drb8Se757ZO91QvlEE5fUQoy0z1ChBODSplTDVMqy7/GhpPBoX1uPZ+TrW4YvKG04OuYbb+3sWQDPPCsvfyWDznKfLp2jmhSTc0vQn61FQBSz+1iJN5jQ5NHU+tf4j3O7KGmvu5A9NJA4iRxFAyk/88bnjPLUM6/I6azJSg5GbaNo59aU5FDawF1KWHg/jHjzuMycGkznm+Devrctn0zkvuonTOvjbyQMygAmzH2bcnGqVZlAhaTAai/UAjFVcnPphH4wGdSe4mwfpw8G/eiYdXDr1fzbyIXVhe8xx62gRGI0sYplTbTkiXC2ahQbA01ZrZW2QSW/yksm4cxt4/EHPo2NUTzPArIo1sKyYsXKqRH7wLJixcqpkSWQsH0otKeDriC+9tExAJcZFgVmLc0Yq4RU2s9D+9DQTk66ZHQgPDEpSGO6qDUfX/cNyQsYEmSVWFleq4ubWV25WrUTGuhEj3pKi3fI7iNePjXln+A60Fg96Iw57PI2J0WlfCChcFBjp9hQgEB7KGd08/odOUetzCBO8ZwAQMBkNGWbW2mIT1eRoDLVVZmmNGoLf4NHE1orS/b39gCUalJkr+BaO+Jo5lfAkiYdRJdUs5OUwNe4S0tlWa6EHYm0M82i+FxkZdbWxCK3EABIIjncgC4ILZCaktn3xg2BbN/+tjj1xwyPKBmA8mf4foCZOo8h4zPquZ+SpUNb43zjm38p585hP70v4ZHDHrPqPB9AkYTLMUG9x8vhaOegHNaBc5NFSE7IV8tyAEiEzltR0eLfBQAu8ByoZ0bxsnxRoJt8xAZRY/4wA3ZFcjkZ7/FQQH4MDqtoTieg9Mf5qS0oM/N+7LyyhUQzFWthWbFi5dSIfWBZsWLl1MiyrjmM/WlOfau5AqBUJkkb8cWI6Tl37kkNvVZoF0NWP3BjBXoTbcgaRZjJeMoWaL2SvBfm/INVTdARu5j4LBxv1JoAiuU6R2P0cDrP6DYiLMpYxqHYJ+XKuCc80stsHKLkFvu7Eter1gUtDsdyRsdMaAqZC6YVRQdHHQAeU94aZBlXuBoSgCsRQ0QMpTxqEcN25Yog5cy0eGEsLVsICU0nMoMeoXqRjAsaYJ2yoH8SycaGBNElQkjIxZhmy4NfmCl/yQj04pwOJAWQMWvpnXd/aJT9fQnV1Vj6UyEvyOYq+ydtSCxbqSLVpWBSzLTPwJAAcMAbr9tuG2XrkuRJXb78ihyxKbfNjbtSVfOjDz+SM8kyAHXGmissySoSEiakIdGUQM/TXCdNblwOdtT7MVHSRy0byuGYIqb5Op5sIQAnx104mtIBelx2TVjrHsuyh3SahLxtQrojzC3nsnguvxMWIGF+RCpLsKGWH7kuZhwLP6FGx1pYVqxYOTViH1hWrFg5NbIEEmpfLLXLgrAAYGdPCJ6Vpj1j4qKnJRS0Y7URaUJsmBN6KVkaEgBBUcsj2FiJikYhnbz3kTI6EOjR4lQO7/ZRB0DQI0sfT2dJBDMPMLH6nLAoL1w4ISKjjTw1xTQmy3iv05YjssFRmXA14V5TZnjGaQRgwFBpleUpYGA0dQUsjIZisWv5SImRSo+NvxI2JTNWd85LRxSmDPQjksdrC6+UxnuR5UFaGVNg9Uzq+AASrb/3tKxKcybnpXskoUxtKxvzSvV7PQAxP3/9jTdl8rzu2p+1we5kF84J5i2WeO6MdmU5i2SEmb6/PpnvC7y4ffKyq5RYRKUg6oWrTxjl2ScvzW7p85RzzkJ2ZkvIrK/gtEhuPJeWgeadzom2jE0XCPBmwDYh4dz/M6KxcpMY7CykceqvO81RpLZllVsi1kgiKQPTPOZuRluMYM5PSSuK3Lwp2QJa5FeOm83ukthGqlasWLFixYoVK1asWLFixYoVK1asWLFixYoVK1asWLFixYqVUyf/D3PcGe48X+nJAAAAAElFTkSuQmCC\n", - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataiter = iter(trainloader)\n", - "images, labels = dataiter.next() # 返回4张图片及标签\n", - "print(' '.join('%11s'%classes[labels[j]] for j in range(4)))\n", - "show(tv.utils.make_grid((images+1)/2)).resize((400,100))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.2 定义网络\n", - "\n", - "拷贝上面的LeNet网络,修改self.conv1第一个参数为3通道,因CIFAR-10是3通道彩图。" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Net(\n", - " (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n", - " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", - " (fc1): Linear(in_features=400, out_features=120, bias=True)\n", - " (fc2): Linear(in_features=120, out_features=84, bias=True)\n", - " (fc3): Linear(in_features=84, out_features=10, bias=True)\n", - ")\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super(Net, self).__init__()\n", - " self.conv1 = nn.Conv2d(3, 6, 5) \n", - " self.conv2 = nn.Conv2d(6, 16, 5) \n", - " self.fc1 = nn.Linear(16*5*5, 120) \n", - " self.fc2 = nn.Linear(120, 84)\n", - " self.fc3 = nn.Linear(84, 10)\n", - "\n", - " def forward(self, x): \n", - " x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) \n", - " x = F.max_pool2d(F.relu(self.conv2(x)), 2) \n", - " x = x.view(x.size()[0], -1) \n", - " x = F.relu(self.fc1(x))\n", - " x = F.relu(self.fc2(x))\n", - " x = self.fc3(x) \n", - " return x\n", - "\n", - "\n", - "net = Net()\n", - "print(net)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.3 定义损失函数和优化器(loss和optimizer)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "from torch import optim\n", - "criterion = nn.CrossEntropyLoss() # 交叉熵损失函数\n", - "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.4 训练网络\n", - "\n", - "所有网络的训练流程都是类似的,不断地执行如下流程:\n", - "\n", - "- 输入数据\n", - "- 前向传播+反向传播\n", - "- 更新参数\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:25: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1, 2000] loss: 2.210\n", - "[1, 4000] loss: 1.958\n", - "[1, 6000] loss: 1.723\n", - "[1, 8000] loss: 1.590\n", - "[1, 10000] loss: 1.532\n", - "[1, 12000] loss: 1.467\n", - "[2, 2000] loss: 1.408\n", - "[2, 4000] loss: 1.374\n", - "[2, 6000] loss: 1.345\n", - "[2, 8000] loss: 1.331\n", - "[2, 10000] loss: 1.338\n", - "[2, 12000] loss: 1.286\n", - "Finished Training\n" - ] - } - ], - "source": [ - "from torch.autograd import Variable\n", - "\n", - "t.set_num_threads(8)\n", - "for epoch in range(2): \n", - " \n", - " running_loss = 0.0\n", - " for i, data in enumerate(trainloader, 0):\n", - " \n", - " # 输入数据\n", - " inputs, labels = data\n", - " inputs, labels = Variable(inputs), Variable(labels)\n", - " \n", - " # 梯度清零\n", - " optimizer.zero_grad()\n", - " \n", - " # forward + backward \n", - " outputs = net(inputs)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward() \n", - " \n", - " # 更新参数 \n", - " optimizer.step()\n", - " \n", - " # 打印log信息\n", - " running_loss += loss.data[0]\n", - " if i % 2000 == 1999: # 每2000个batch打印一下训练状态\n", - " print('[%d, %5d] loss: %.3f' \\\n", - " % (epoch+1, i+1, running_loss / 2000))\n", - " running_loss = 0.0\n", - "print('Finished Training')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "此处仅训练了2个epoch(遍历完一遍数据集称为一个epoch),来看看网络有没有效果。将测试图片输入到网络中,计算它的label,然后与实际的label进行比较。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "dataiter = iter(testloader)\n", - "images, labels = dataiter.next() # 一个batch返回4张图片\n", - "print('实际的label: ', ' '.join(\\\n", - " '%08s'%classes[labels[j]] for j in range(4)))\n", - "show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接着计算网络预测的label:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "预测结果: cat ship ship ship\n" - ] - } - ], - "source": [ - "# 计算图片在每个类别上的分数\n", - "outputs = net(Variable(images))\n", - "# 得分最高的那个类\n", - "_, predicted = t.max(outputs.data, 1)\n", - "\n", - "print('预测结果: ', ' '.join('%5s'\\\n", - " % classes[predicted[j]] for j in range(4)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "已经可以看出效果,准确率50%,但这只是一部分的图片,再来看看在整个测试集上的效果。" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10000张测试集中的准确率为: 54 %\n" - ] - } - ], - "source": [ - "correct = 0 # 预测正确的图片数\n", - "total = 0 # 总共的图片数\n", - "for data in testloader:\n", - " images, labels = data\n", - " outputs = net(Variable(images))\n", - " _, predicted = t.max(outputs.data, 1)\n", - " total += labels.size(0)\n", - " correct += (predicted == labels).sum()\n", - "\n", - "print('10000张测试集中的准确率为: %d %%' % (100 * correct / total))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "训练的准确率远比随机猜测(准确率10%)好,证明网络确实学到了东西。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4.5 在GPU训练\n", - "就像之前把Tensor从CPU转到GPU一样,模型也可以类似地从CPU转到GPU。" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [], - "source": [ - "if t.cuda.is_available():\n", - " net.cuda()\n", - " images = images.cuda()\n", - " labels = labels.cuda()\n", - " output = net(Variable(images))\n", - " loss= criterion(output,Variable(labels))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "如果发现在GPU上并没有比CPU提速很多,实际上是因为网络比较小,GPU没有完全发挥自己的真正实力。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "对PyTorch的基础介绍至此结束。总结一下,本节主要包含以下内容。\n", - "\n", - "1. Tensor: 类似Numpy数组的数据结构,与Numpy接口类似,可方便地互相转换。\n", - "2. autograd/Variable: Variable封装了Tensor,并提供自动求导功能。\n", - "3. nn: 专门为神经网络设计的接口,提供了很多有用的功能(神经网络层,损失函数,优化器等)。\n", - "4. 神经网络训练: 以CIFAR-10分类为例演示了神经网络的训练流程,包括数据加载、网络搭建、训练及测试。\n", - "\n", - "通过本节的学习,相信读者可以体会出PyTorch具有接口简单、使用灵活等特点。从下一章开始,本书将深入系统地讲解PyTorch的各部分知识。" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch/1_NN/data.txt b/6_pytorch/data.txt similarity index 100% rename from 6_pytorch/1_NN/data.txt rename to 6_pytorch/data.txt diff --git a/6_pytorch/1_NN/imgs/ResNet.png b/6_pytorch/imgs/ResNet.png similarity index 100% rename from 6_pytorch/1_NN/imgs/ResNet.png rename to 6_pytorch/imgs/ResNet.png diff --git a/6_pytorch/0_basic/imgs/com_graph.svg b/6_pytorch/imgs/com_graph.svg similarity index 100% rename from 6_pytorch/0_basic/imgs/com_graph.svg rename to 6_pytorch/imgs/com_graph.svg diff --git a/6_pytorch/0_basic/imgs/com_graph_backward.svg b/6_pytorch/imgs/com_graph_backward.svg similarity index 100% rename from 6_pytorch/0_basic/imgs/com_graph_backward.svg rename to 6_pytorch/imgs/com_graph_backward.svg diff --git a/6_pytorch/1_NN/imgs/lena.png b/6_pytorch/imgs/lena.png similarity index 100% rename from 6_pytorch/1_NN/imgs/lena.png rename to 6_pytorch/imgs/lena.png diff --git a/6_pytorch/1_NN/imgs/lena3.png b/6_pytorch/imgs/lena3.png similarity index 100% rename from 6_pytorch/1_NN/imgs/lena3.png rename to 6_pytorch/imgs/lena3.png diff --git a/6_pytorch/1_NN/imgs/lena512.png b/6_pytorch/imgs/lena512.png similarity index 100% rename from 6_pytorch/1_NN/imgs/lena512.png rename to 6_pytorch/imgs/lena512.png diff --git a/6_pytorch/1_NN/imgs/linear_sep.png b/6_pytorch/imgs/linear_sep.png similarity index 100% rename from 6_pytorch/1_NN/imgs/linear_sep.png rename to 6_pytorch/imgs/linear_sep.png diff --git a/6_pytorch/1_NN/imgs/multi_perceptron.png b/6_pytorch/imgs/multi_perceptron.png similarity index 100% rename from 6_pytorch/1_NN/imgs/multi_perceptron.png rename to 6_pytorch/imgs/multi_perceptron.png diff --git a/6_pytorch/1_NN/imgs/nn-forward.gif b/6_pytorch/imgs/nn-forward.gif similarity index 100% rename from 6_pytorch/1_NN/imgs/nn-forward.gif rename to 6_pytorch/imgs/nn-forward.gif diff --git a/6_pytorch/1_NN/imgs/residual.png b/6_pytorch/imgs/residual.png similarity index 100% rename from 6_pytorch/1_NN/imgs/residual.png rename to 6_pytorch/imgs/residual.png diff --git a/6_pytorch/1_NN/imgs/resnet1.png b/6_pytorch/imgs/resnet1.png similarity index 100% rename from 6_pytorch/1_NN/imgs/resnet1.png rename to 6_pytorch/imgs/resnet1.png diff --git a/6_pytorch/0_basic/imgs/tensor_data_structure.svg b/6_pytorch/imgs/tensor_data_structure.svg similarity index 100% rename from 6_pytorch/0_basic/imgs/tensor_data_structure.svg rename to 6_pytorch/imgs/tensor_data_structure.svg diff --git a/6_pytorch/1_NN/imgs/trans.bkp.PNG b/6_pytorch/imgs/trans.bkp.PNG similarity index 100% rename from 6_pytorch/1_NN/imgs/trans.bkp.PNG rename to 6_pytorch/imgs/trans.bkp.PNG diff --git a/6_pytorch/1_NN/optimizer/6_1-sgd.ipynb b/6_pytorch/optimizer/6_1-sgd.ipynb similarity index 100% rename from 6_pytorch/1_NN/optimizer/6_1-sgd.ipynb rename to 6_pytorch/optimizer/6_1-sgd.ipynb diff --git a/6_pytorch/1_NN/optimizer/6_2-momentum.ipynb b/6_pytorch/optimizer/6_2-momentum.ipynb similarity index 100% rename from 6_pytorch/1_NN/optimizer/6_2-momentum.ipynb rename to 6_pytorch/optimizer/6_2-momentum.ipynb diff --git a/6_pytorch/1_NN/optimizer/6_3-adagrad.ipynb b/6_pytorch/optimizer/6_3-adagrad.ipynb similarity index 100% rename from 6_pytorch/1_NN/optimizer/6_3-adagrad.ipynb rename to 6_pytorch/optimizer/6_3-adagrad.ipynb diff --git a/6_pytorch/1_NN/optimizer/6_4-rmsprop.ipynb b/6_pytorch/optimizer/6_4-rmsprop.ipynb similarity index 100% rename from 6_pytorch/1_NN/optimizer/6_4-rmsprop.ipynb rename to 6_pytorch/optimizer/6_4-rmsprop.ipynb diff --git a/6_pytorch/1_NN/optimizer/6_5-adadelta.ipynb b/6_pytorch/optimizer/6_5-adadelta.ipynb similarity index 100% rename from 6_pytorch/1_NN/optimizer/6_5-adadelta.ipynb rename to 6_pytorch/optimizer/6_5-adadelta.ipynb diff --git a/6_pytorch/1_NN/optimizer/6_6-adam.ipynb b/6_pytorch/optimizer/6_6-adam.ipynb similarity index 100% rename from 6_pytorch/1_NN/optimizer/6_6-adam.ipynb rename to 6_pytorch/optimizer/6_6-adam.ipynb diff --git a/6_pytorch/2_CNN/1-basic_conv.ipynb b/7_deep_learning/1_CNN/1-basic_conv.ipynb similarity index 100% rename from 6_pytorch/2_CNN/1-basic_conv.ipynb rename to 7_deep_learning/1_CNN/1-basic_conv.ipynb diff --git a/6_pytorch/2_CNN/2-batch-normalization.ipynb b/7_deep_learning/1_CNN/2-batch-normalization.ipynb similarity index 100% rename from 6_pytorch/2_CNN/2-batch-normalization.ipynb rename to 7_deep_learning/1_CNN/2-batch-normalization.ipynb diff --git a/6_pytorch/2_CNN/3-lr-decay.ipynb b/7_deep_learning/1_CNN/3-lr-decay.ipynb similarity index 100% rename from 6_pytorch/2_CNN/3-lr-decay.ipynb rename to 7_deep_learning/1_CNN/3-lr-decay.ipynb diff --git a/6_pytorch/2_CNN/4-data-augumentation.ipynb b/7_deep_learning/1_CNN/4-data-augumentation.ipynb similarity index 100% rename from 6_pytorch/2_CNN/4-data-augumentation.ipynb rename to 7_deep_learning/1_CNN/4-data-augumentation.ipynb diff --git a/6_pytorch/2_CNN/4-regularization.ipynb b/7_deep_learning/1_CNN/4-regularization.ipynb similarity index 100% rename from 6_pytorch/2_CNN/4-regularization.ipynb rename to 7_deep_learning/1_CNN/4-regularization.ipynb diff --git a/6_pytorch/2_CNN/5-data-augumentation.ipynb b/7_deep_learning/1_CNN/5-data-augumentation.ipynb similarity index 100% rename from 6_pytorch/2_CNN/5-data-augumentation.ipynb rename to 7_deep_learning/1_CNN/5-data-augumentation.ipynb diff --git a/6_pytorch/2_CNN/5-regularization.ipynb b/7_deep_learning/1_CNN/5-regularization.ipynb similarity index 100% rename from 6_pytorch/2_CNN/5-regularization.ipynb rename to 7_deep_learning/1_CNN/5-regularization.ipynb diff --git a/6_pytorch/2_CNN/6-vgg.ipynb b/7_deep_learning/1_CNN/6-vgg.ipynb similarity index 100% rename from 6_pytorch/2_CNN/6-vgg.ipynb rename to 7_deep_learning/1_CNN/6-vgg.ipynb diff --git a/6_pytorch/2_CNN/7-googlenet.ipynb b/7_deep_learning/1_CNN/7-googlenet.ipynb similarity index 100% rename from 6_pytorch/2_CNN/7-googlenet.ipynb rename to 7_deep_learning/1_CNN/7-googlenet.ipynb diff --git a/6_pytorch/2_CNN/8-resnet.ipynb b/7_deep_learning/1_CNN/8-resnet.ipynb similarity index 100% rename from 6_pytorch/2_CNN/8-resnet.ipynb rename to 7_deep_learning/1_CNN/8-resnet.ipynb diff --git a/6_pytorch/2_CNN/9-densenet.ipynb b/7_deep_learning/1_CNN/9-densenet.ipynb similarity index 100% rename from 6_pytorch/2_CNN/9-densenet.ipynb rename to 7_deep_learning/1_CNN/9-densenet.ipynb diff --git a/6_pytorch/2_CNN/CNN_Introduction.pptx b/7_deep_learning/1_CNN/CNN_Introduction.pptx similarity index 100% rename from 6_pytorch/2_CNN/CNN_Introduction.pptx rename to 7_deep_learning/1_CNN/CNN_Introduction.pptx diff --git a/6_pytorch/2_CNN/README.md b/7_deep_learning/1_CNN/README.md similarity index 100% rename from 6_pytorch/2_CNN/README.md rename to 7_deep_learning/1_CNN/README.md diff --git a/6_pytorch/2_CNN/cat.png b/7_deep_learning/1_CNN/cat.png similarity index 100% rename from 6_pytorch/2_CNN/cat.png rename to 7_deep_learning/1_CNN/cat.png diff --git a/6_pytorch/2_CNN/images/data_normalize.png b/7_deep_learning/1_CNN/images/data_normalize.png similarity index 100% rename from 6_pytorch/2_CNN/images/data_normalize.png rename to 7_deep_learning/1_CNN/images/data_normalize.png diff --git a/6_pytorch/2_CNN/utils.py b/7_deep_learning/1_CNN/utils.py similarity index 100% rename from 6_pytorch/2_CNN/utils.py rename to 7_deep_learning/1_CNN/utils.py diff --git a/6_pytorch/3_RNN/nlp/n-gram.ipynb b/7_deep_learning/2_RNN/nlp/n-gram.ipynb similarity index 100% rename from 6_pytorch/3_RNN/nlp/n-gram.ipynb rename to 7_deep_learning/2_RNN/nlp/n-gram.ipynb diff --git a/6_pytorch/3_RNN/nlp/seq-lstm.ipynb b/7_deep_learning/2_RNN/nlp/seq-lstm.ipynb similarity index 100% rename from 6_pytorch/3_RNN/nlp/seq-lstm.ipynb rename to 7_deep_learning/2_RNN/nlp/seq-lstm.ipynb diff --git a/6_pytorch/3_RNN/nlp/word-embedding.ipynb b/7_deep_learning/2_RNN/nlp/word-embedding.ipynb similarity index 100% rename from 6_pytorch/3_RNN/nlp/word-embedding.ipynb rename to 7_deep_learning/2_RNN/nlp/word-embedding.ipynb diff --git a/6_pytorch/3_RNN/pytorch-rnn.ipynb b/7_deep_learning/2_RNN/pytorch-rnn.ipynb similarity index 100% rename from 6_pytorch/3_RNN/pytorch-rnn.ipynb rename to 7_deep_learning/2_RNN/pytorch-rnn.ipynb diff --git a/6_pytorch/3_RNN/rnn-for-image.ipynb b/7_deep_learning/2_RNN/rnn-for-image.ipynb similarity index 100% rename from 6_pytorch/3_RNN/rnn-for-image.ipynb rename to 7_deep_learning/2_RNN/rnn-for-image.ipynb diff --git a/6_pytorch/3_RNN/time-series/data.csv b/7_deep_learning/2_RNN/time-series/data.csv similarity index 100% rename from 6_pytorch/3_RNN/time-series/data.csv rename to 7_deep_learning/2_RNN/time-series/data.csv diff --git a/6_pytorch/3_RNN/time-series/lstm-time-series.ipynb b/7_deep_learning/2_RNN/time-series/lstm-time-series.ipynb similarity index 100% rename from 6_pytorch/3_RNN/time-series/lstm-time-series.ipynb rename to 7_deep_learning/2_RNN/time-series/lstm-time-series.ipynb diff --git a/6_pytorch/3_RNN/utils.py b/7_deep_learning/2_RNN/utils.py similarity index 100% rename from 6_pytorch/3_RNN/utils.py rename to 7_deep_learning/2_RNN/utils.py diff --git a/6_pytorch/4_GAN/autoencoder.ipynb b/7_deep_learning/3_GAN/autoencoder.ipynb similarity index 100% rename from 6_pytorch/4_GAN/autoencoder.ipynb rename to 7_deep_learning/3_GAN/autoencoder.ipynb diff --git a/6_pytorch/4_GAN/gan.ipynb b/7_deep_learning/3_GAN/gan.ipynb similarity index 100% rename from 6_pytorch/4_GAN/gan.ipynb rename to 7_deep_learning/3_GAN/gan.ipynb diff --git a/6_pytorch/4_GAN/vae.ipynb b/7_deep_learning/3_GAN/vae.ipynb similarity index 100% rename from 6_pytorch/4_GAN/vae.ipynb rename to 7_deep_learning/3_GAN/vae.ipynb diff --git a/6_pytorch/5_NLP/README.md b/7_deep_learning/4_NLP/README.md similarity index 100% rename from 6_pytorch/5_NLP/README.md rename to 7_deep_learning/4_NLP/README.md diff --git a/6_pytorch/5_NLP/Word2Vec.ipynb b/7_deep_learning/4_NLP/Word2Vec.ipynb similarity index 100% rename from 6_pytorch/5_NLP/Word2Vec.ipynb rename to 7_deep_learning/4_NLP/Word2Vec.ipynb diff --git a/6_pytorch/5_NLP/images/word2vec_01.jpeg b/7_deep_learning/4_NLP/images/word2vec_01.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_01.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_01.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_02.jpeg b/7_deep_learning/4_NLP/images/word2vec_02.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_02.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_02.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_03.jpeg b/7_deep_learning/4_NLP/images/word2vec_03.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_03.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_03.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_04.jpeg b/7_deep_learning/4_NLP/images/word2vec_04.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_04.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_04.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_05.jpeg b/7_deep_learning/4_NLP/images/word2vec_05.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_05.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_05.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_06.jpeg b/7_deep_learning/4_NLP/images/word2vec_06.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_06.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_06.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_07.jpeg b/7_deep_learning/4_NLP/images/word2vec_07.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_07.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_07.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_08.jpeg b/7_deep_learning/4_NLP/images/word2vec_08.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_08.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_08.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_09.jpeg b/7_deep_learning/4_NLP/images/word2vec_09.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_09.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_09.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_10.jpeg b/7_deep_learning/4_NLP/images/word2vec_10.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_10.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_10.jpeg diff --git a/6_pytorch/5_NLP/images/word2vec_11.jpeg b/7_deep_learning/4_NLP/images/word2vec_11.jpeg similarity index 100% rename from 6_pytorch/5_NLP/images/word2vec_11.jpeg rename to 7_deep_learning/4_NLP/images/word2vec_11.jpeg diff --git a/README.md b/README.md index 477fb57..2f76272 100644 --- a/README.md +++ b/README.md @@ -50,25 +50,26 @@ - [nn/param_initialize](6_pytorch/1_NN/5-param_initialize.ipynb) - [optim/sgd](6_pytorch/1_NN/optimizer/6_1-sgd.ipynb) - [optim/adam](6_pytorch/1_NN/optimizer/6_6-adam.ipynb) +9. [Deep Learning](7_deep_learning/README.md) - CNN - - [CNN Introduction](6_pytorch/2_CNN/CNN_Introduction.pptx) + - [CNN Introduction](7_deep_learning/1_CNN/CNN_Introduction.pptx) - [CNN simple demo](demo_code/3_CNN_MNIST.py) - - [cnn/basic_conv](6_pytorch/2_CNN/1-basic_conv.ipynb) - - [cnn/batch-normalization](6_pytorch/2_CNN/2-batch-normalization.ipynb) - - [cnn/lr-decay](6_pytorch/2_CNN/3-lr-decay.ipynb) - - [cnn/regularization](6_pytorch/2_CNN/4-regularization.ipynb) - - [cnn/vgg](6_pytorch/2_CNN/6-vgg.ipynb) - - [cnn/googlenet](6_pytorch/2_CNN/7-googlenet.ipynb) - - [cnn/resnet](6_pytorch/2_CNN/8-resnet.ipynb) - - [cnn/densenet](6_pytorch/2_CNN/9-densenet.ipynb) + - [cnn/basic_conv](7_deep_learning/1_CNN/1-basic_conv.ipynb) + - [cnn/batch-normalization](7_deep_learning/1_CNN/2-batch-normalization.ipynb) + - [cnn/lr-decay](7_deep_learning/2_CNN/1-lr-decay.ipynb) + - [cnn/regularization](7_deep_learning/1_CNN/4-regularization.ipynb) + - [cnn/vgg](7_deep_learning/1_CNN/6-vgg.ipynb) + - [cnn/googlenet](7_deep_learning/1_CNN/7-googlenet.ipynb) + - [cnn/resnet](7_deep_learning/1_CNN/8-resnet.ipynb) + - [cnn/densenet](7_deep_learning/1_CNN/9-densenet.ipynb) - RNN - - [rnn/pytorch-rnn](6_pytorch/3_RNN/pytorch-rnn.ipynb) - - [rnn/rnn-for-image](6_pytorch/3_RNN/rnn-for-image.ipynb) - - [rnn/lstm-time-series](6_pytorch/3_RNN/time-series/lstm-time-series.ipynb) + - [rnn/pytorch-rnn](7_deep_learning/2_RNN/pytorch-rnn.ipynb) + - [rnn/rnn-for-image](7_deep_learning/2_RNN/rnn-for-image.ipynb) + - [rnn/lstm-time-series](7_deep_learning/2_RNN/time-series/lstm-time-series.ipynb) - GAN - - [gan/autoencoder](6_pytorch/4_GAN/autoencoder.ipynb) - - [gan/vae](6_pytorch/4_GAN/vae.ipynb) - - [gan/gan](6_pytorch/4_GAN/gan.ipynb) + - [gan/autoencoder](7_deep_learning/3_GAN/autoencoder.ipynb) + - [gan/vae](7_deep_learning/3_GAN/vae.ipynb) + - [gan/gan](7_deep_learning/3_GAN/gan.ipynb) diff --git a/references_tips/InstallPython.md b/references_tips/InstallPython.md index 1c2a319..a75cf10 100644 --- a/references_tips/InstallPython.md +++ b/references_tips/InstallPython.md @@ -41,10 +41,26 @@ bash ./Anaconda3-2020.11-Linux-x86_64.sh 参考这里的[conda安装和软件源设置说明](https://mirrors.bfsu.edu.cn/help/anaconda/) -``` -conda config --set show_channel_urls yes -conda config --add channels https://mirrors.bfsu.edu.cn/anaconda/pkgs/main/ -conda config --add channels https://mirrors.bfsu.edu.cn/anaconda/pkgs/free/ + +各系统都可以通过修改用户目录下的 `.condarc` 文件。Windows 用户无法直接创建名为 `.condarc` 的文件,可先执行 `conda config --set show_channel_urls yes` 生成该文件之后再修改。 + +Linux下,打开文件编辑器 `gedit ~/.condarc`,然后把下面的内容拷贝到这个文件中: +``` +channels: + - defaults +show_channel_urls: true +default_channels: + - https://mirrors.bfsu.edu.cn/anaconda/pkgs/main + - https://mirrors.bfsu.edu.cn/anaconda/pkgs/r + - https://mirrors.bfsu.edu.cn/anaconda/pkgs/msys2 +custom_channels: + conda-forge: https://mirrors.bfsu.edu.cn/anaconda/cloud + msys2: https://mirrors.bfsu.edu.cn/anaconda/cloud + bioconda: https://mirrors.bfsu.edu.cn/anaconda/cloud + menpo: https://mirrors.bfsu.edu.cn/anaconda/cloud + pytorch: https://mirrors.bfsu.edu.cn/anaconda/cloud + pytorch-lts: https://mirrors.bfsu.edu.cn/anaconda/cloud + simpleitk: https://mirrors.bfsu.edu.cn/anaconda/cloud ```