{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.2 autograd\n", "\n", "用Tensor训练网络很方便,但从上一小节最后的线性回归例子来看,反向传播过程需要手动实现。这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查。torch.autograd就是为方便用户使用,而专门开发的一套自动求导引擎,它能够根据输入和前向传播过程自动构建计算图,并执行反向传播。\n", "\n", "计算图(Computation Graph)是现代深度学习框架如PyTorch和TensorFlow等的核心,其为高效自动求导算法——反向传播(Back Propogation)提供了理论支持,了解计算图在实际写程序过程中会有极大的帮助。本节将涉及一些基础的计算图知识,但并不要求读者事先对此有深入的了解。关于计算图的基础知识推荐阅读Christopher Olah的文章[^1]。\n", "\n", "[^1]: http://colah.github.io/posts/2015-08-Backprop/\n", "\n", "\n", "### 3.2.1 Variable\n", "PyTorch在autograd模块中实现了计算图的相关功能,autograd中的核心数据结构是Variable。Variable封装了tensor,并记录对tensor的操作记录用来构建计算图。Variable的数据结构如图3-2所示,主要包含三个属性:\n", "\n", "- `data`:保存variable所包含的tensor\n", "- `grad`:保存`data`对应的梯度,`grad`也是variable,而不是tensor,它与`data`形状一致。 \n", "- `grad_fn`: 指向一个`Function`,记录tensor的操作历史,即它是什么操作的输出,用来构建计算图。如果某一个变量是由用户创建,则它为叶子节点,对应的grad_fn等于None。\n", "\n", "\n", "![图3-2:Variable数据结构](imgs/autograd_Variable.png)\n", "\n", "Variable的构造函数需要传入tensor,同时有两个可选参数:\n", "- `requires_grad (bool)`:是否需要对该variable进行求导\n", "- `volatile (bool)`:意为”挥发“,设置为True,则构建在该variable之上的图都不会求导,专为推理阶段设计\n", "\n", "Variable提供了大部分tensor支持的函数,但其不支持部分`inplace`函数,因这些函数会修改tensor自身,而在反向传播中,variable需要缓存原来的tensor来计算反向传播梯度。如果想要计算各个Variable的梯度,只需调用根节点variable的`backward`方法,autograd会自动沿着计算图反向传播,计算每一个叶子节点的梯度。\n", "\n", "`variable.backward(grad_variables=None, retain_graph=None, create_graph=None)`主要有如下参数:\n", "\n", "- grad_variables:形状与variable一致,对于`y.backward()`,grad_variables相当于链式法则${dz \\over dx}={dz \\over dy} \\times {dy \\over dx}$中的$\\textbf {dz} \\over \\textbf {dy}$。grad_variables也可以是tensor或序列。\n", "- retain_graph:反向传播需要缓存一些中间结果,反向传播之后,这些缓存就被清空,可通过指定这个参数不清空缓存,用来多次反向传播。\n", "- create_graph:对反向传播过程再次构建计算图,可通过`backward of backward`实现求高阶导数。\n", "\n", "上述描述可能比较抽象,如果没有看懂,不用着急,会在本节后半部分详细介绍,下面先看几个例子。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "import torch as t\n", "from torch.autograd import Variable as V" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 1 1 1 1\n", " 1 1 1 1\n", " 1 1 1 1\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 从tensor中创建variable,指定需要求导\n", "a = V(t.ones(3,4), requires_grad = True) \n", "a" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 0 0 0 0\n", " 0 0 0 0\n", " 0 0 0 0\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = V(t.zeros(3,4))\n", "b" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 1 1 1 1\n", " 1 1 1 1\n", " 1 1 1 1\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 函数的使用与tensor一致\n", "# 也可写成c = a + b\n", "c = a.add(b)\n", "c" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "d = c.sum()\n", "d.backward() # 反向传播" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "(12.0, Variable containing:\n", " 12\n", " [torch.FloatTensor of size 1])" ] }, "execution_count": 6, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# 注意二者的区别\n", "# 前者在取data后变为tensor,而后从tensor计算sum得到float\n", "# 后者计算sum后仍然是Variable\n", "c.data.sum(), c.sum()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 1 1 1 1\n", " 1 1 1 1\n", " 1 1 1 1\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a.grad" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "(True, False, True)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 此处虽然没有指定c需要求导,但c依赖于a,而a需要求导,\n", "# 因此c的requires_grad属性会自动设为True\n", "a.requires_grad, b.requires_grad, c.requires_grad" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(True, True, False)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 由用户创建的variable属于叶子节点,对应的grad_fn是None\n", "a.is_leaf, b.is_leaf, c.is_leaf" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# c.grad是None, 因c不是叶子节点,它的梯度是用来计算a的梯度\n", "# 所以虽然c.requires_grad = True,但其梯度计算完之后即被释放\n", "c.grad is None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "计算下面这个函数的导函数:\n", "$$\n", "y = x^2\\bullet e^x\n", "$$\n", "它的导函数是:\n", "$$\n", "{dy \\over dx} = 2x\\bullet e^x + x^2 \\bullet e^x\n", "$$\n", "来看看autograd的计算结果与手动求导计算结果的误差。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def f(x):\n", " '''计算y'''\n", " y = x**2 * t.exp(x)\n", " return y\n", "\n", "def gradf(x):\n", " '''手动求导函数'''\n", " dx = 2*x*t.exp(x) + x**2*t.exp(x)\n", " return dx" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 7.8454 0.4475 5.5884 0.1406\n", " 0.4044 0.5008 0.4989 13.3268\n", " 0.3547 0.0623 1.0497 4.2674\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.randn(3,4), requires_grad = True)\n", "y = f(x)\n", "y" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 19.0962 2.1796 14.4631 1.0203\n", " -0.3276 0.1172 -0.1745 29.7573\n", " 1.8619 -0.3699 3.9812 11.6386\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.backward(t.ones(y.size())) # grad_variables形状与y一致\n", "x.grad" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 19.0962 2.1796 14.4631 1.0203\n", " -0.3276 0.1172 -0.1745 29.7573\n", " 1.8619 -0.3699 3.9812 11.6386\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# autograd的计算结果与利用公式手动计算的结果一致\n", "gradf(x) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2.2 计算图\n", "\n", "PyTorch中`autograd`的底层采用了计算图,计算图是一种特殊的有向无环图(DAG),用于记录算子与变量之间的关系。一般用矩形表示算子,椭圆形表示变量。如表达式$ \\textbf {z = wx + b}$可分解为$\\textbf{y = wx}$和$\\textbf{z = y + 
b}$,其计算图如图3-3所示,图中`MUL`,`ADD`都是算子,$\\textbf{w}$,$\\textbf{x}$,$\\textbf{b}$即变量。\n", "\n", "![图3-3:computation graph](imgs/com_graph.svg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如上有向无环图中,$\\textbf{X}$和$\\textbf{b}$是叶子节点(leaf node),这些节点通常由用户自己创建,不依赖于其他变量。$\\textbf{z}$称为根节点,是计算图的最终目标。利用链式法则很容易求得各个叶子节点的梯度。\n", "$${\\partial z \\over \\partial b} = 1,\\space {\\partial z \\over \\partial y} = 1\\\\\n", "{\\partial y \\over \\partial w }= x,{\\partial y \\over \\partial x}= w\\\\\n", "{\\partial z \\over \\partial x}= {\\partial z \\over \\partial y} {\\partial y \\over \\partial x}=1 * w\\\\\n", "{\\partial z \\over \\partial w}= {\\partial z \\over \\partial y} {\\partial y \\over \\partial w}=1 * x\\\\\n", "$$\n", "而有了计算图,上述链式求导即可利用计算图的反向传播自动完成,其过程如图3-4所示。\n", "\n", "![图3-4:计算图的反向传播](imgs/com_graph_backward.svg)\n", "\n", "\n", "在PyTorch实现中,autograd会随着用户的操作,记录生成当前variable的所有操作,并由此建立一个有向无环图。用户每进行一个操作,相应的计算图就会发生改变。更底层的实现中,图中记录了操作`Function`,每一个变量在图中的位置可通过其`grad_fn`属性在图中的位置推测得到。在反向传播过程中,autograd沿着这个图从当前变量(根节点$\\textbf{z}$)溯源,可以利用链式求导法则计算所有叶子节点的梯度。每一个前向传播操作的函数都有与之对应的反向传播函数用来计算输入的各个variable的梯度,这些函数的函数名通常以`Backward`结尾。下面结合代码学习autograd的实现细节。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "x = V(t.ones(1))\n", "b = V(t.rand(1), requires_grad = True)\n", "w = V(t.rand(1), requires_grad = True)\n", "y = w * x # 等价于y=w.mul(x)\n", "z = y + b # 等价于z=y.add(b)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "(False, True, True)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.requires_grad, b.requires_grad, w.requires_grad" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 虽然未指定y.requires_grad为True,但由于y依赖于需要求导的w\n", "# 故而y.requires_grad为True\n", "y.requires_grad" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(True, True, True)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.is_leaf, w.is_leaf, b.is_leaf" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(False, False)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.is_leaf, z.is_leaf" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# grad_fn可以查看这个variable的反向传播函数,\n", "# z是add函数的输出,所以它的反向传播函数是AddBackward\n", "z.grad_fn " ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "((, 0),\n", " (, 0))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# next_functions保存grad_fn的输入,是一个tuple,tuple的元素也是Function\n", "# 第一个是y,它是乘法(mul)的输出,所以对应的反向传播函数y.grad_fn是MulBackward\n", "# 第二个是b,它是叶子节点,由用户创建,grad_fn为None,但是有\n", "z.grad_fn.next_functions " ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# variable的grad_fn对应着和图中的function相对应\n", "z.grad_fn.next_functions[0][0] == y.grad_fn" ] 
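}, { "cell_type": "markdown", "metadata": {}, "source": [ "借助上面介绍的`grad_fn`和`next_functions`属性,还可以手动遍历整张计算图。下面给出一个简单的示意(假设沿用上面定义的`z`,函数名`print_graph`仅为说明而取):\n", "\n", "```python\n", "def print_graph(fn, depth=0):\n", "    # 递归打印以fn为根的反向传播函数树,叶子节点对应AccumulateGrad或None\n", "    if fn is None:\n", "        return\n", "    print('  ' * depth + type(fn).__name__)\n", "    for next_fn, _ in fn.next_functions:\n", "        print_graph(next_fn, depth + 1)\n", "\n", "print_graph(z.grad_fn)\n", "```" ]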
}, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "((, 0), (None, 0))" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 第一个是w,叶子节点,需要求导,梯度是累加的\n", "# 第二个是x,叶子节点,不需要求导,所以为None\n", "y.grad_fn.next_functions" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(None, None)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 叶子节点的grad_fn是None\n", "w.grad_fn,x.grad_fn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "计算w的梯度的时候,需要用到x的数值(${\\partial y\\over \\partial w} = x $),这些数值在前向过程中会保存成buffer,在计算完梯度之后会自动清空。为了能够多次反向传播需要指定`retain_graph`来保留这些buffer。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 1\n", "[torch.FloatTensor of size 1]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 使用retain_graph来保存buffer\n", "z.backward(retain_graph=True)\n", "w.grad" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 2\n", "[torch.FloatTensor of size 1]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 多次反向传播,梯度累加,这也就是w中AccumulateGrad标识的含义\n", "z.backward()\n", "w.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch使用的是动态图,它的计算图在每次前向传播时都是从头开始构建,所以它能够使用Python控制语句(如for、if等)根据需求创建计算图。这点在自然语言处理领域中很有用,它意味着你不需要事先构建所有可能用到的图的路径,图在运行时才构建。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 1\n", "[torch.FloatTensor of size 1]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def abs(x):\n", " if x.data[0]>0: return x\n", " else: return -x\n", "x = V(t.ones(1),requires_grad=True)\n", "y = abs(x)\n", "y.backward()\n", "x.grad" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable containing:\n", "-1\n", "[torch.FloatTensor of size 1]\n", "\n" ] } ], "source": [ "x = V(-1*t.ones(1),requires_grad=True)\n", "y = abs(x)\n", "y.backward()\n", "print(x.grad)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 0\n", " 0\n", " 0\n", " 6\n", " 3\n", " 2\n", "[torch.FloatTensor of size 6]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x):\n", " result = 1\n", " for ii in x:\n", " if ii.data[0]>0: result=ii*result\n", " return result\n", "x = V(t.arange(-2,4),requires_grad=True)\n", "y = f(x) # y = x[3]*x[4]*x[5]\n", "y.backward()\n", "x.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "变量的`requires_grad`属性默认为False,如果某一个节点requires_grad被设置为True,那么所有依赖它的节点`requires_grad`都是True。这其实很好理解,对于$ \\textbf{x}\\to \\textbf{y} \\to \\textbf{z}$,x.requires_grad = True,当需要计算$\\partial z \\over \\partial x$时,根据链式法则,$\\frac{\\partial z}{\\partial x} = \\frac{\\partial z}{\\partial y} \\frac{\\partial y}{\\partial x}$,自然也需要求$ \\frac{\\partial z}{\\partial y}$,所以y.requires_grad会被自动标为True. 
\n", "\n", "`volatile=True`是另外一个很重要的标识,它能够将所有依赖于它的节点全部都设为`volatile=True`,其优先级比`requires_grad=True`高。`volatile=True`的节点不会求导,即使`requires_grad=True`,也无法进行反向传播。对于不需要反向传播的情景(如inference,即测试推理时),该参数可实现一定程度的速度提升,并节省约一半显存,因其不需要分配空间计算梯度。" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(False, True, True)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.ones(1))\n", "w = V(t.rand(1), requires_grad=True)\n", "y = x * w\n", "# y依赖于w,而w.requires_grad = True\n", "x.requires_grad, w.requires_grad, y.requires_grad" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(False, True, False)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.ones(1), volatile=True)\n", "w = V(t.rand(1), requires_grad = True)\n", "y = x * w\n", "# y依赖于w和x,但x.volatile = True,w.requires_grad = True\n", "x.requires_grad, w.requires_grad, y.requires_grad" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(True, False, True)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.volatile, w.volatile, y.volatile" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:\n", "- 使用autograd.grad函数\n", "- 使用hook\n", "\n", "`autograd.grad`和`hook`方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用`hook`方法,但是在实际使用中应尽量避免修改grad的值。" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(True, True, True)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.ones(3), requires_grad=True)\n", "w = V(t.rand(3), requires_grad=True)\n", "y = x * w\n", "# y依赖于w,而w.requires_grad = True\n", "z = y.sum()\n", "x.requires_grad, w.requires_grad, y.requires_grad" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Variable containing:\n", " 0.3776\n", " 0.1184\n", " 0.8554\n", " [torch.FloatTensor of size 3], Variable containing:\n", " 1\n", " 1\n", " 1\n", " [torch.FloatTensor of size 3], None)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 非叶子节点grad计算完之后自动清空,y.grad是None\n", "z.backward()\n", "(x.grad, w.grad, y.grad)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Variable containing:\n", " 1\n", " 1\n", " 1\n", " [torch.FloatTensor of size 3],)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 第一种方法:使用grad获取中间变量的梯度\n", "x = V(t.ones(3), requires_grad=True)\n", "w = V(t.rand(3), requires_grad=True)\n", "y = x * w\n", "z = y.sum()\n", "# z对y的梯度,隐式调用backward()\n", "t.autograd.grad(z, y)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y的梯度: \n", " Variable containing:\n", " 1\n", " 1\n", " 1\n", "[torch.FloatTensor of size 3]\n", "\n" ] } ], "source": [ "# 第二种方法:使用hook\n", "# hook是一个函数,输入是梯度,不应该有返回值\n", "def variable_hook(grad):\n", " print('y的梯度: \\r\\n',grad)\n", "\n", "x = V(t.ones(3), requires_grad=True)\n", "w = V(t.rand(3), requires_grad=True)\n", "y = x * w\n", "# 注册hook\n", "hook_handle = y.register_hook(variable_hook)\n", "z = y.sum()\n", "z.backward()\n", "\n", "# 
除非你每次都要用hook,否则用完之后记得移除hook\n", "hook_handle.remove()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "最后再来看看variable中grad属性和backward函数`grad_variables`参数的含义,这里直接下结论:\n", "\n", "- variable $\\textbf{x}$的梯度是目标函数${f(x)} $对$\\textbf{x}$的梯度,$\\frac{df(x)}{dx} = (\\frac {df(x)}{dx_0},\\frac {df(x)}{dx_1},...,\\frac {df(x)}{dx_N})$,形状和$\\textbf{x}$一致。\n", "- 对于y.backward(grad_variables)中的grad_variables相当于链式求导法则中的$\\frac{\\partial z}{\\partial x} = \\frac{\\partial z}{\\partial y} \\frac{\\partial y}{\\partial x}$中的$\\frac{\\partial z}{\\partial y}$。z是目标函数,一般是一个标量,故而$\\frac{\\partial z}{\\partial y}$的形状与variable $\\textbf{y}$的形状一致。`z.backward()`在一定程度上等价于y.backward(grad_y)。`z.backward()`省略了grad_variables参数,是因为$z$是一个标量,而$\\frac{\\partial z}{\\partial z} = 1$" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 2\n", " 4\n", " 6\n", "[torch.FloatTensor of size 3]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.arange(0,3), requires_grad=True)\n", "y = x**2 + x*2\n", "z = y.sum()\n", "z.backward() # 从z开始反向传播\n", "x.grad" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 2\n", " 4\n", " 6\n", "[torch.FloatTensor of size 3]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.arange(0,3), requires_grad=True)\n", "y = x**2 + x*2\n", "z = y.sum()\n", "y_grad_variables = V(t.Tensor([1,1,1])) # dz/dy\n", "y.backward(y_grad_variables) #从y开始反向传播\n", "x.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "另外值得注意的是,只有对variable的操作才能使用autograd,如果对variable的data直接进行操作,将无法使用反向传播。除了对参数初始化,一般我们不会修改variable.data的值。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在PyTorch中计算图的特点可总结如下:\n", "\n", "- autograd根据用户对variable的操作构建其计算图。对变量的操作抽象为`Function`。\n", "- 对于那些不是任何函数(Function)的输出,由用户创建的节点称为叶子节点,叶子节点的`grad_fn`为None。叶子节点中需要求导的variable,具有`AccumulateGrad`标识,因其梯度是累加的。\n", "- variable默认是不需要求导的,即`requires_grad`属性默认为False,如果某一个节点requires_grad被设置为True,那么所有依赖它的节点`requires_grad`都为True。\n", "- variable的`volatile`属性默认为False,如果某一个variable的`volatile`属性被设为True,那么所有依赖它的节点`volatile`属性都为True。volatile属性为True的节点不会求导,volatile的优先级比`requires_grad`高。\n", "- 多次反向传播时,梯度是累加的。反向传播的中间缓存会被清空,为进行多次反向传播需指定`retain_graph`=True来保存这些缓存。\n", "- 非叶子节点的梯度计算完之后即被清空,可以使用`autograd.grad`或`hook`技术获取非叶子节点的值。\n", "- variable的grad与data形状一致,应避免直接修改variable.data,因为对data的直接操作无法利用autograd进行反向传播\n", "- 反向传播函数`backward`的参数`grad_variables`可以看成链式求导的中间结果,如果是标量,可以省略,默认为1\n", "- PyTorch采用动态图设计,可以很方便地查看中间层的输出,动态的设计计算图结构。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2.3 扩展autograd\n", "\n", "\n", "目前绝大多数函数都可以使用`autograd`实现反向求导,但如果需要自己写一个复杂的函数,不支持自动反向求导怎么办? 
写一个`Function`,实现它的前向传播和反向传播代码,`Function`对应于计算图中的矩形, 它接收参数,计算并返回结果。下面给出一个例子。\n", "\n", "```python\n", "\n", "class Mul(Function):\n", " \n", " @staticmethod\n", " def forward(ctx, w, x, b, x_requires_grad = True):\n", " ctx.x_requires_grad = x_requires_grad\n", " ctx.save_for_backward(w,x)\n", " output = w * x + b\n", " return output\n", " \n", " @staticmethod\n", " def backward(ctx, grad_output):\n", " w,x = ctx.saved_variables\n", " grad_w = grad_output * x\n", " if ctx.x_requires_grad:\n", " grad_x = grad_output * w\n", " else:\n", " grad_x = None\n", " grad_b = grad_output * 1\n", " return grad_w, grad_x, grad_b, None\n", "```\n", "\n", "分析如下:\n", "\n", "- 自定义的Function需要继承autograd.Function,没有构造函数`__init__`,forward和backward函数都是静态方法\n", "- forward函数的输入和输出都是Tensor,backward函数的输入和输出都是Variable\n", "- backward函数的输出和forward函数的输入一一对应,backward函数的输入和forward函数的输出一一对应\n", "- backward函数的grad_output参数即t.autograd.backward中的`grad_variables`\n", "- 如果某一个输入不需要求导,直接返回None,如forward中的输入参数x_requires_grad显然无法对它求导,直接返回None即可\n", "- 反向传播可能需要利用前向传播的某些中间结果,需要进行保存,否则前向传播结束后这些对象即被释放\n", "\n", "Function的使用利用Function.apply(variable)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "from torch.autograd import Function\n", "class MultiplyAdd(Function):\n", " \n", " @staticmethod\n", " def forward(ctx, w, x, b): \n", " print('type in forward',type(x))\n", " ctx.save_for_backward(w,x)\n", " output = w * x + b\n", " return output\n", " \n", " @staticmethod\n", " def backward(ctx, grad_output): \n", " w,x = ctx.saved_variables\n", " print('type in backward',type(x))\n", " grad_w = grad_output * x\n", " grad_x = grad_output * w\n", " grad_b = grad_output * 1\n", " return grad_w, grad_x, grad_b " ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "开始前向传播\n", "type in backwardtype in forward \n", "\n", "开始反向传播\n" ] }, { "data": { "text/plain": [ "(None, Variable containing:\n", " 1\n", " [torch.FloatTensor of size 1], Variable containing:\n", " 1\n", " [torch.FloatTensor of size 1])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.ones(1))\n", "w = V(t.rand(1), requires_grad = True)\n", "b = V(t.rand(1), requires_grad = True)\n", "print('开始前向传播')\n", "z=MultiplyAdd.apply(w, x, b)\n", "print('开始反向传播')\n", "z.backward()\n", "\n", "# x不需要求导,中间过程还是会计算它的导数,但随后被清空\n", "x.grad, w.grad, b.grad" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "开始前向传播\n", "type in forward \n", "开始反向传播\n", "type in backward \n" ] }, { "data": { "text/plain": [ "(Variable containing:\n", " 1\n", " [torch.FloatTensor of size 1], Variable containing:\n", " 0.9633\n", " [torch.FloatTensor of size 1], Variable containing:\n", " 1\n", " [torch.FloatTensor of size 1])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.ones(1))\n", "w = V(t.rand(1), requires_grad = True)\n", "b = V(t.rand(1), requires_grad = True)\n", "print('开始前向传播')\n", "z=MultiplyAdd.apply(w,x,b)\n", "print('开始反向传播')\n", "\n", "# 调用MultiplyAdd.backward\n", "# 输出grad_w, grad_x, grad_b\n", "z.grad_fn.apply(V(t.ones(1)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
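"在实际使用中,通常会把`Function.apply`再封装成一个普通的Python函数,使它用起来和内置算子一致。下面是一个简单的示意(函数名`multiply_add`为假设的名字,沿用上面定义的`MultiplyAdd`、`w`、`x`、`b`):\n", "\n", "```python\n", "def multiply_add(w, x, b):\n", "    # 仅作封装:内部仍然调用MultiplyAdd.apply\n", "    return MultiplyAdd.apply(w, x, b)\n", "\n", "z = multiply_add(w, x, b)  # 与直接调用MultiplyAdd.apply(w, x, b)等价\n", "```\n", "\n",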
"之所以forward函数的输入是tensor,而backward函数的输入是variable,是为了实现高阶求导。backward函数的输入输出虽然是variable,但在实际使用时autograd.Function会将输入variable提取为tensor,并将计算结果的tensor封装成variable返回。在backward函数中,之所以也要对variable进行操作,是为了能够计算梯度的梯度(backward of backward)。下面举例说明,有关torch.autograd.grad的更详细使用请参照文档。" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Variable containing:\n", " 10\n", " [torch.FloatTensor of size 1],)" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = V(t.Tensor([5]), requires_grad=True)\n", "y = x ** 2\n", "grad_x = t.autograd.grad(y, x, create_graph=True)\n", "grad_x # dy/dx = 2 * x" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Variable containing:\n", " 2\n", " [torch.FloatTensor of size 1],)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grad_grad_x = t.autograd.grad(grad_x[0],x)\n", "grad_grad_x # 二阶导数 d(2x)/dx = 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这种设计虽然能让`autograd`具有高阶求导功能,但其也限制了Tensor的使用,因autograd中反向传播的函数只能利用当前已经有的Variable操作。这个设计是在`0.2`版本新加入的,为了更好的灵活性,也为了兼容旧版本的代码,PyTorch还提供了另外一种扩展autograd的方法。PyTorch提供了一个装饰器`@once_differentiable`,能够在backward函数中自动将输入的variable提取成tensor,把计算结果的tensor自动封装成variable。有了这个特性我们就能够很方便的使用numpy/scipy中的函数,操作不再局限于variable所支持的操作。但是这种做法正如名字中所暗示的那样只能求导一次,它打断了反向传播图,不再支持高阶求导。\n", "\n", "\n", "上面所描述的都是新式Function,还有个legacy Function,可以带有`__init__`方法,`forward`和`backwad`函数也不需要声明为`@staticmethod`,但随着版本更迭,此类Function将越来越少遇到,在此不做更多介绍。\n", "\n", "此外在实现了自己的Function之后,还可以使用`gradcheck`函数来检测实现是否正确。`gradcheck`通过数值逼近来计算梯度,可能具有一定的误差,通过控制`eps`的大小可以控制容忍的误差。\n", "关于这部份的内容可以参考github上开发者们的讨论[^3]。\n", "\n", "[^3]: https://github.com/pytorch/pytorch/pull/1016" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "下面举例说明如何利用Function实现sigmoid Function。" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "class Sigmoid(Function):\n", " \n", " @staticmethod\n", " def forward(ctx, x): \n", " output = 1 / (1 + t.exp(-x))\n", " ctx.save_for_backward(output)\n", " return output\n", " \n", " @staticmethod\n", " def backward(ctx, grad_output): \n", " output, = ctx.saved_variables\n", " grad_x = output * (1 - output) * grad_output\n", " return grad_x " ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 采用数值逼近方式检验计算梯度的公式对不对\n", "test_input = V(t.randn(3,4), requires_grad=True)\n", "t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "232 µs ± 68.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "191 µs ± 6.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "215 µs ± 23.1 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "def f_sigmoid(x):\n", " y = Sigmoid.apply(x)\n", " y.backward(t.ones(x.size()))\n", " \n", "def f_naive(x):\n", " y = 1/(1 + t.exp(-x))\n", " y.backward(t.ones(x.size()))\n", " \n", "def f_th(x):\n", " y = t.sigmoid(x)\n", " y.backward(t.ones(x.size()))\n", " \n", "x=V(t.randn(100, 100), requires_grad=True)\n", "%timeit -n 100 f_sigmoid(x)\n", "%timeit -n 100 f_naive(x)\n", "%timeit -n 100 f_th(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "显然`f_sigmoid`要比单纯利用`autograd`加减和乘方操作实现的函数快不少,因为f_sigmoid的backward优化了反向传播的过程。另外可以看出系统实现的buildin接口(t.sigmoid)更快。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2.4 小试牛刀: 用Variable实现线性回归\n", "在上一节中讲解了利用tensor实现线性回归,在这一小节中,将讲解如何利用autograd/Variable实现线性回归,以此感受autograd的便捷之处。" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from torch.autograd import Variable as V\n", "%matplotlib inline\n", "from matplotlib import pyplot as plt\n", "from IPython import display" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "# 设置随机数种子,为了在不同人电脑上运行时下面的输出一致\n", "t.manual_seed(1000) \n", "\n", "def get_fake_data(batch_size=8):\n", " ''' 产生随机数据:y = x*2 + 3,加上了一些噪声'''\n", " x = t.rand(batch_size,1) * 20\n", " y = x * 2 + (1 + t.randn(batch_size, 1))*3\n", " return x, y" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD11JREFUeJzt3V+MXGd9xvHvU8eU5U+1gWxQvEAN\nKHKpSLHpKkobKaJA64AQMVFRSVtktbShEqhQkEVML4CLKkHmj6peRAokTS5oVArGQS3FWCFtWqmk\n3eAQO3XdFMqfrN14KSzQsqKO+fVix2Bv1t6Z9c7OzLvfj7SamXfP6DxaK0/mvOedc1JVSJJG308N\nOoAkaXVY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGXLSWO7vkkktq8+bNa7lL\nSRp5Dz744LeqamK57da00Ddv3sz09PRa7lKSRl6Sr3eznVMuktQIC12SGmGhS1Ijli30JE9N8s9J\nvpzkkSTv74y/IMkDSR5N8pdJntL/uJKkc+nmE/oPgVdU1UuBrcC1Sa4CPgB8pKouB74DvLl/MSVJ\ny1l2lUst3AHjfzovN3Z+CngF8Jud8buA9wG3rn5ESRpN+w7OsGf/UY7NzbNpfIxd27ewY9tk3/bX\n1Rx6kg1JHgJOAAeArwBzVfVEZ5PHgP6llKQRs+/gDLv3HmJmbp4CZubm2b33EPsOzvRtn10VelWd\nqqqtwHOBK4EXL7XZUu9NcmOS6STTs7OzK08qSSNkz/6jzJ88ddbY/MlT7Nl/tG/77GmVS1XNAX8H\nXAWMJzk9ZfNc4Ng53nNbVU1V1dTExLJfdJKkJhybm+9pfDV0s8plIsl45/kY8CrgCHAf8OudzXYC\n9/QrpCSNmk3jYz2Nr4ZuPqFfBtyX5GHgX4ADVfXXwLuBdyb5D+DZwO19SylJI2bX9i2Mbdxw1tjY\nxg3s2r6lb/vsZpXLw8C2Jca/ysJ8uiRpkdOrWdZylcuaXpxLktaTHdsm+1rgi/nVf0lqhIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtd\nkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWp\nERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqRHLFnqS5yW5L8mRJI8keXtn/H1J\nZpI81Pl5Tf/jSpLO5aIutnkCeFdVfSnJM4EHkxzo/O4jVfXB/sWTJHVr2UKvquPA8c7z7yc5Akz2\nO5gkqTc9zaEn2QxsAx7oDL0tycNJ7khy8SpnkyT1oOtCT/IM4FPAO6rqe8CtwIuArSx8gv/QOd53\nY5LpJNOzs7OrEFmStJSuCj3JRhbK/ONVtRegqh6vqlNV9SPgo8CVS723qm6rqqmqmpqYmFit3JKk\nRbpZ5RLgduBIVX34jPHLztjs9cDh1Y8nSepWN6tcrgbeBBxK8lBn7D3ADUm2AgV8DXhLXxJKkrrS\nzSqXfwSyxK8+u/pxJEkr5TdFJakRFrokNcJCl6RGdHNSVGrSvoMz7Nl/lGNz82waH2PX9i3s2OaX\noDW6LHStS/sOzrB77yHmT54CYGZunt17DwFY6hpZTrloXdqz/+iPy/y0+ZOn2LP/6IASSRfOQte6\ndGxuvqdxaRRY6FqXNo2P9TQujQILXevSru1bGNu44ayxsY0b2LV9y4ASSRfOk6Jal06f+HSVi1pi\noWvd2rFt0gJXU5xykaRGWOiS1AgLXZIaYaFLUiMsdElqhKtcJKlHw3phNwtdknowz
Bd2c8pFknow\nzBd2s9AlqQfDfGE3C12SejDMF3az0CWpB8N8YTdPikpSD4b5wm4WuiT1aFgv7OaUiyQ1wkKXpEZY\n6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIasWyhJ3lekvuSHEny\nSJK3d8afleRAkkc7jxf3P64k6Vy6+YT+BPCuqnoxcBXw1iQ/D9wE3FtVlwP3dl5rBO07OMPVt3yB\nF9z0N1x9yxfYd3Bm0JEkrcCyhV5Vx6vqS53n3weOAJPAdcBdnc3uAnb0K6T65/QNb2fm5il+csNb\nS10aPT3NoSfZDGwDHgCeU1XHYaH0gUtXO5z6b5hveCupN10XepJnAJ8C3lFV3+vhfTcmmU4yPTs7\nu5KM6qNhvuGtpN50VehJNrJQ5h+vqr2d4ceTXNb5/WXAiaXeW1W3VdVUVU1NTEysRmatomG+4a2k\n3nSzyiXA7cCRqvrwGb/6DLCz83wncM/qx1O/DfMNbyX1ppt7il4NvAk4lOShzth7gFuATyR5M/AN\n4A39iah+GuYb3krqTapqzXY2NTVV09PTa7Y/SWpBkgeramq57fymqCQ1wkKXpEZY6JLUCAtdkhph\noUtSI7pZtqhVsu/gjMsDJfWNhb5GTl8E6/R1U05fBAuw1CWtCgt9jZzvIlgW+uB41KSWWOhrxItg\nDR+PmtQaT4quES+CNXy8dLBaY6GvES+CNXw8alJrLPQ1smPbJDdffwWT42MEmBwf4+brr/DQfoA8\nalJrnENfQzu2TVrgQ2TX9i1nzaGDR00abRa61i0vHazWWOha1zxqUkucQ5ekRljoktQIC12SGmGh\nS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrok\nNWIkbnCx7+CMd5WRpGUMfaHvOzhz1n0fZ+bm2b33EIClLklnGPoplz37j551E1+A+ZOn2LP/6IAS\nSdJwGvpCPzY339O4JK1XQ1/om8bHehqXpPVq2UJPckeSE0kOnzH2viQzSR7q/LymXwF3bd/C2MYN\nZ42NbdzAru1b+rVLSRpJ3XxCvxO4donxj1TV1s7PZ1c31k/s2DbJzddfweT4GAEmx8e4+forPCEq\nSYssu8qlqu5Psrn/Uc5tx7ZJC1ySlnEhc+hvS/JwZ0rm4lVLJElakZUW+q3Ai4CtwHHgQ+faMMmN\nSaaTTM/Ozq5wd5Kk5ayo0Kvq8ao6VVU/Aj4KXHmebW+rqqmqmpqYmFhpTknSMlZU6EkuO+Pl64HD\n59pWkrQ2lj0pmuRu4OXAJUkeA94LvDzJVqCArwFv6WNGSVIXulnlcsMSw7f3IYsk6QIM/TdFJUnd\nsdAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgL\nXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAl\nqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGLFvoSe5IciLJ\n4TPGnpXkQJJHO48X9zemJGk53XxCvxO4dtHYTcC9VXU5cG/ntSRpgJYt9Kq6H/j2ouHrgLs6z+8C\ndqxyLklSj1Y6h/6cqjoO0Hm8dPUiSZJWou8nRZPcmGQ6yfTs7Gy/dydJ69ZKC/3xJJcBdB5PnGvD\nqrqtqqaqampiYmKFu5MkLWelhf4ZYGfn+U7gntWJI0laqW6WLd4N/BOwJcljSd4M3AL8apJHgV/t\nvJYkDdBFy21QVTec41evXOUskqQL4DdFJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUu\nSY2w0CWpERa6JDXCQpekRix7LZdRs+/gDHv2H+XY3DybxsfYtX0LO7ZNDjqWJPVdU4W+7+AMu/ce\nYv7kKQBm5ubZvfcQgKUuqXlNTbns2X/0x2V+2vzJU+zZf3RAiSRp7TRV6Mfm5nsal6SWNFXom8bH\nehqXpJY0Vei7tm9hbOOGs8bGNm5g1/YtA0okSWunqZOip098uspF0nrUVKHDQqlb4JLWo6amXCRp\nPbPQJakRFrokNcJCl6RGWOiS1IhU1drtLJkFvr7MZpcA31qDOBfCjKtnFHKacXWMQkYYzpw/W1UT\ny220poXejSTTVTU16BznY8bVMwo5zbg6RiEjjE7OpTjlIkmNsNAlqRHDWOi3DTpAF8y4ekYhpxlX\nxyhkhNHJ+SRDN4cuSVqZYfyELklagaEq9CRfS3IoyUNJpgedZylJxpN8Msm/JTmS5JcGnelMSbZ0\n/n6nf76X5B2DzrVYkj9K8kiSw0nuTvLUQWdaLMnbO/keGaa/YZI7kpxIcviMsWclOZDk0c7jxUOY\n8Q2dv+WPkgx8Fck5Mu7p/Lf9cJJPJxkfZMZeDVWhd/xKVW0d4mVDfwp8rqp+DngpcGTAec5SVUc7\nf7+twC8CPwA+PeBYZ0kyCfwhMFVVLwE2AG8cbKqzJXkJ8PvAlSz8O782yeWDTfVjdwLXLhq7Cbi3\nqi4H7u28HqQ7eXLGw8D1wP1rnmZpd/LkjAeAl1TVLwD/Duxe61AXYhgLfWgl+RngGuB2gKr6v6qa\nG2yq83ol8JWqWu7LXINwETCW5CLgacCxAedZ7MXAF6vqB1X1BPD3wOsHnAmAqrof+Pai4euAuzrP\n7wJ2rGmoRZbKWFVHqmpobvB7joyf7/x7A3wReO6aB7sAw1boBXw+yYNJbhx0mCW8EJgF/jzJwSQf\nS/L0QYc6jzcCdw86xGJVNQN8EPgGcBz4blV9frCpnuQwcE2SZyd5GvAa4HkDznQ+z6mq4wCdx0sH\nnKcFvwv87aBD9GLYCv3qqnoZ8GrgrUmuGXSgRS4CXgbcWlXbgP9l8Ie2S0ryFOB1wF8NOstinfnd\n64AXAJuApyf57cGmOltVHQE+wMIh+OeALwNPnPdNakaSP2bh3/vjg87Si6Eq9Ko61nk8wcK875WD\nTfQkjwGPVdUDndefZKHgh9GrgS9V1eODDrKEVwH/WVWzVXUS2Av88oAzPUlV3V5VL6uqa1g4NH90\n0JnO4/EklwF0Hk8MOM/ISrITeC3wWzVi67qHptCTPD3JM08/B36NhcPeoVFV/wV8M8npu06/EvjX\nAUY6nxsYwumWjm8AVyV5WpKw8HccqpPLAEku7Tw+n4WTecP69wT4DLCz83wncM8As4ysJNcC7wZe\nV1U/GHSeXg3NF4uSvJCfrMa4CPiLqvqTAUZaUpKtwMeApwBfBX6nqr4z2FRn68z5fhN4YVV9d9B5\nlpLk/cBvsHBYexD4var64WBTnS3JPwDPBk4C76yqewccCYAkdwMvZ+GqgI8D7wX2AZ8Ans/C/zDf\nUFWLT5wOOuO3gT8DJoA54KGq2j5kGXcD
Pw38d2ezL1bVHwwk4AoMTaFLki7M0Ey5SJIujIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1Ij/h/CJYJPfXoR0gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 来看看产生x-y分布是什么样的\n", "x, y = get_fake_data()\n", "plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4VOX5xvHvkz1AICA7EsMaQJag\nEQXUWrWCW6EuqP1VsWqpbW0LCArWVtwqrVbp4qXFaqVWawBRBFFqBUWtG0gWIOyCLGEnhCVAlvf3\nRwYLIZMMyayZ+3NduZg5c+bM09PjnZP3vPMcc84hIiINX0yoCxARkeBQ4IuIRAkFvohIlFDgi4hE\nCQW+iEiUUOCLiEQJBb6ISJRQ4IuIRAkFvohIlIgL5oe1bNnSpaenB/MjRSTK5G/ZV+PrqY3iaZ+a\nTKxZkCqqvyVLluxyzrWq73aCGvjp6eksXrw4mB8pIlFm8OQFbCkqOWm5GfzlprO4sm+7EFRVP2a2\n0R/b8XlIx8xizWypmc31PO9kZp+Z2RozyzazBH8UJCJSH+OHZJAcH3vCshiDSVefGZFh70+nMob/\nS6DguOe/A55yznUD9gK3+7MwEZG6GN6/AzcN6EiMZ8SmWXI8f7i+HyMHpYe0rnDg05COmZ0OXAk8\nCow1MwMuBr7vWWUaMAl4JgA1ioj4ZP/hUh6as4IZSzbTu0NTptyQSdfWKaEuK2z4OoY/BbgHOLbn\nTgOKnHNlnuebgQ5+rk1ExGdfbNjDmOwcthaVcNe3u/KLS7qREKeJiMerNfDN7Cpgh3NuiZlddGxx\nNatW21jfzEYBowDS0tLqWKaISPWOllXw1H9W8+wH6+jYvBEz7hzI2We0CHVZYcmXM/zBwHfN7Aog\nCWhK5Rl/qpnFec7yTwe2Vvdm59xUYCpAVlaW7rYiIn6zevt+Rr+aw4rCYm48pyP3X9WLJolBnXwY\nUWr9e8c5N9E5d7pzLh24EVjgnPs/YCFwnWe1kcDsgFUpInKcigrH8x99xVV//ojtxYd57pYsJl/b\nV2Ffi/rsnXuBV83sEWAp8Lx/ShIR8a5wXwnjZuTy8drdXNKjNZOv7UurlMRQlxURTinwnXPvA+97\nHq8HBvi/JBGR6r2Zu5X7X8+nrMLx2DV9uPGcjlgEfWM21PT3j4iEvX2HSvnNm8uYnbOV/mmpPDUi\nk/SWjUNdVsRR4ItIWPt47S7Gzchlx/4jjP1Od356URfiYjXdsi4U+CISlg6XlvP4/FU8/9FXdG7V\nmFk/GUS/jqmhLiuiKfBFJOws37qPMdk5rN5+gFsGnsHEy3uSnBBb+xulRgp8EQkb5RWOqYvW8+S7\nq0htlMCLPzyHizJah7qsBkOBLyJhYdOeQ9w9PZfPN+zh8t5tefR7fWjRWE14/UmBLyIh5Zxj5pLN\nPDhnBQB/uL4f15zVQdMtA0CBLyIhs+fgUe6blc87y7cxIL0FfxjRj44tGoW6rAZLgS8iIbFw1Q7u\nmZlH0aGjTLi8Bz+6oDOxMTqrDyQFvogE1aGjZTw2byUvfbqR7m2aMO2HA+jVvmmoy4oKCnwRCZqc\nTUWMzc5h/a6D3HF+J8YNySApXtMtg0WBLyIBV1ZewdML1/GnBWtok5LIK3ecy6CuLUNdVtRR4ItI\nQH216yBjsnPI2VTE8Mz2PDisN82S40NdVlRS4ItIQDjneOXzr3lkbgHxscafb+rP1f3ah7qsqKbA\nFxG/27H/MBNey2fByh2c37Ulj1/fl3bNkkNdVtRT4IuIX81fvo2Js/I5eKSMB67uxciB6cRoumVY\nUOCLiF8cOFLGQ3OWM33xZs5s35QpN2TSrU1KqMuS4yjwRaTeFm/Yw5jpOWzZW8JPL+rC6Eu7kxCn\nnvXhptbAN7MkYBGQ6Fl/pnPuATN7EfgWsM+z6q3OuZxAFSoi4edoWQV/fG81z7y/jg7Nk8n+8UDO\nSW8R6rLEC1/O8I8AFzvnDphZPPCRmb3teW28c25m4MoTkXC1dsd+RmfnsGxLMdeffTq/uboXKUma\nbhnOag1855wDDniexnt+XCCLEpHwVVHhmPbJBia/vZLGiXH89eazGXJm21CXJT7waZDNzGLNLAfY\nAbzrnPvM89KjZpZnZk+ZWaKX944ys8Vmtnjnzp1+KltEQmHbvsOM/PvnPDhnBYO7tuSd0Rco7COI\nVZ7A+7iyWSrwOvBzYDewDUgApgLrnHMP1fT+rKwst3jx4rpXKyIhMyd3K/e/sYyjZRXcf1VPvj8g\nrd49699YuoXH569ia1EJ7VOTGT8kg+H9O/ip4obDzJY457Lqu51TmqXjnCsys/eBoc65JzyLj5jZ\n34Fx9S1GRMLPvpJSHpi9jDdytpLZMZWnbsikU8vG9d7uG0u3MHFWPiWl5QBsKSph4qx8AIV+gNQ6\npGNmrTxn9phZMnApsNLM2nmWGTAcWBbIQkUk+P67bheXT1nEnLxCxlzanZl3DvRL2AM8Pn/VN2F/\nTElpOY/PX+WX7cvJfDnDbwdMM7NYKn9BTHfOzTWzBWbWCjAgB7gzgHWKSBAdLi3nD/9exd8++or0\n0xrz2k8Gkdkx1a+fsbWo5JSWS/35MksnD+hfzfKLA1KRiIRUQWExY7JzWLltPz84L437ruhJowT/\nf0ezfWoyW6oJ9/ap6rkTKPoqnIgAUF7hmLpoHcP+8jG7Dx7l7z88h0eG9wlI2AOMH5JBcpWbnyTH\nxzJ+SEZAPk/UWkFEgM17D3H39Fw++2oPQ85sw2PX9KVF44SAfuaxC7OapRM8CnyRKOacY9aXW5j0\n5nIc8MT1/bj2rA71nm7pq+H9Oyjgg0iBLxKl9h48yq/eyGde/jbOSW/OkyMy6diiUajLkgBS4ItE\noQ9W72T8jFz2HjrKvUN7MOrCzsSqZ32Dp8AXiSIlR8uZ/HYB0z7ZSLfWTXjh1nPo3aFZqMuSIFHg\ni0SJvM1FjMnOYd3Og9w2uBP3DM0gqcosGWnYFPgiDVxZeQXPvL+OP763hpZNEnn5jnMZ3LVlqMuS\nEFDgizRgG3cfZEx2Dl9+XcR3+7Xn4WG9adZIPeujlQJfpAFyzvHqF5t4eO4K4mKMP96YybBMTX+M\ndgp8kQZm5/4jTJyVx38KdjCoy2k8cX0/tSsQQIEv0qD8Z8V27n0tj/1
Hyvj1Vb344aB0YjTdUjwU\n+CINwMEjZTw8dwWvfrGJnu2a8soNmWS0TQl1WRJmFPgiEW7Jxr2Myc5h095D3PmtLoz5TjcS4zTd\nUk6mwBeJUKXlFfzpvTU8vXAt7Zolkz1qIAM6tQh1WRLGFPgiEWjtjgOMyc4hf8s+rjv7dB64uhcp\nSZpuKTVT4ItEEOcc//hkI7+dV0CjhFie/cFZDO3dLtRlSYSoNfDNLAlYBCR61p/pnHvAzDoBrwIt\ngC+Bm51zRwNZrEg02158mPEz81i0eicXZbTi99f2pXXTpFCXJRHElzP8I8DFzrkDZhYPfGRmbwNj\ngaecc6+a2bPA7cAzAaxVJGrNyy/kvtfzOVxazsPDe/ODc9OC1rNeGg5f7mnrgAOep/GeHwdcDHzf\ns3waMAkFvohfFR8uZdLs5cxauoV+pzfjyRsy6dKqSajLkgjl0xi+mcUCS4CuwNPAOqDIOVfmWWUz\noO9ti/jRp+t3c/f0XLYVH+aXl3Tjrou7Eh+r21BL3fkU+M65ciDTzFKB14Ge1a1W3XvNbBQwCiAt\nLa2OZYpEjyNl5Tz579VM/XA9Z7RoxMw7B9I/rXmoy5IG4JRm6TjniszsfeA8INXM4jxn+acDW728\nZyowFSArK6vaXwoiUmnltmJGv5rDym37+f65adx/ZU8aJWgynfhHrX8fmlkrz5k9ZpYMXAoUAAuB\n6zyrjQRmB6pIkYauosLxtw/Xc+WfPmL19v0AfLBqJ/9evj3ElUlD4supQztgmmccPwaY7pyba2Yr\ngFfN7BFgKfB8AOsUabC2FJUwbnoun6zfTYxBhfvf8omz8gEY3l+XyKT+fJmlkwf0r2b5emBAIIoS\niQbOOWbnbOXXs5dRUeFITY6nqKT0hHVKSst5fP4qBb74hS75i4RA0aGj3PWvpYzOziGjTQpv//JC\n9lUJ+2O2FpUEuTppqHQ1SCTIPlyzk3Ezctl94Cjjh2Rw57e6EBtjtE9NZks14a6bl4i/6AxfJEgO\nl5Yz6c3l3Pz856QkxfPGzwbzs293JdZzg5LxQzJIjj+xrXFyfCzjh2SEolxpgHSGLxIEy7bsY3R2\nDmt3HODWQelMuLwHSVXC/dg4/ePzV7G1qIT2qcmMH5Kh8XvxGwW+SACVVzie/WAdT727mtOaJPDS\n7QO4oFsrr+sP799BAS8Bo8AXCZCvdx9izPQclmzcy5V92/Ho8N6kNkoIdVkSxRT4In7mnGPG4s08\nOGc5MTHGlBsyGZbZXt0tJeQU+CJ+tPvAESbOyuffK7YzsPNpPDGiHx00y0bChAJfxE8WrNzOPTPz\nKC4p4/4re3Lb4E7ExOisXsKHAl+kng4eKePReQW88tnX9Gibwj/vOJcebZuGuiyRkyjwRerhy6/3\nMjY7h417DvHjCzsz9rLuJMbF1v5GkRBQ4IvUQWl5BX9esJanF66lbdMk/vWj8ziv82mhLkukRgp8\nkVO0bucBxmbnkLt5H9f078CkYWfSNCk+1GWJ1EqBL+Ij5xz//OxrHn1rBUnxsTz9/bO4sm+7UJcl\n4jMFvogPdhQf5p7X8nh/1U4u7N6Kx6/rS5umSaEuS+SUKPBFvHhj6RYen7+KLUUlxBjExhgPDTuT\nm887Q1+ikoikwBepxhtLtzDhtTwOl1UAlXehijejaVK8wl4iltoji1TjkbdWfBP2xxwpq+Dx+atC\nVJFI/flyE/OOZrbQzArMbLmZ/dKzfJKZbTGzHM/PFYEvVySwjpSVM/ntlew6cLTa13X3KYlkvgzp\nlAF3O+e+NLMUYImZvet57Snn3BOBK08keFZv388vX82hoLCYRgmxHDpaftI6uvuURDJfbmJeCBR6\nHu83swJADbulwaiocPz9vxv43TsrSUmM47lbsjh4pIyJs/IpKf1f6Ef73aeOXcTWzVki1yldtDWz\ndKA/8BkwGLjLzG4BFlP5V8Deat4zChgFkJaWVs9yRfyrcF8J42bk8vHa3VzaszWTr+1LyyaJ37yu\ngKv0xtItJ/wC3FJUwsRZ+QBRu08ikTnnfFvRrAnwAfCoc26WmbUBdgEOeBho55y7raZtZGVlucWL\nF9ezZBH/mJ2zhV+/sYzDZRU0io+lqKSUDlEe7N4Mnryg2husd0hN5uMJF4egouhiZkucc1n13Y5P\nZ/hmFg+8BrzsnJsF4JzbftzrzwFz61uMSDDsO1TKr2cv483craSf1ojCfYcpKikFdObqjbeL1bqI\nHVl8maVjwPNAgXPuyeOWH/+d8u8By/xfnoh/fbx2F0OmLGJefiHjLuvO0bIKjlSZfllSWq7pl1V4\nu1iti9iRxZd5+IOBm4GLq0zB/L2Z5ZtZHvBtYEwgCxWpj8Ol5Tw0ZwX/97fPaJQYy6yfDuKui7tR\nuO9wtevrzPVE44dkkBx/YtvnaL+IHYl8maXzEVDdVwvn+b8cEf9bvnUfo1/NYc2OA4wceAYTLu9J\nckJleLVPTa52bFpnric6Nryli9iRTa0VpMEqr3BMXbSeJ99dRfNGCUy7bQDf6t7qhHXGD8nQ9Esf\nDe/fQQEf4RT40iBt2nOIsdNz+GLDXq7o05ZHh/eheeOEk9bTmatEEwW+NCjOOWYu2cyDc1ZgwJMj\n+vG9/h1qbHimM1eJFgp8aTD2HDzKxFl5zF++nQGdWvDkiH6c3rxRqMsSCRsKfGkQFq7cwfiZeRSX\nlHLfFT24/fzOxMaojbHI8RT4EtEOHS3jt/MK+OenX5PRJoWXbh9Az3ZNQ12WSFhS4EvEytlUxNjs\nHL7afZA7zu/EuCEZJFWZKy4i/6PAl4hTVl7BXxau5c8L1tImJZGX7ziXQV1ahroskbCnwJeI8tWu\ng4zOziF3UxHDM9vz4LDeNEuOD3VZIhFBgS8RwTnHK59/zSNzC0iIi+HPN/Xn6n7tQ12WSERR4EvY\n27H/MPfOzGPhqp1c0K0lj1/Xj7bNkkJdlkjEUeBLWJu/fBsTZ+Vz8EgZk67uxS0D04mpx3RL3bVJ\nopkCX8LSgSNlPPjmcmYs2cyZ7Zsy5YZMurVJqdc2ddcmiXYKfAk7X2zYw9jpOWzZW8Jd3+7KLy7p\nRkKcL528a/b4/FUnNEmD//W+V+BLNFDgS9g4WlbBlP+s5tkP1nF680ZM//FAstJb+G37umuTRDsF\nvoSF1dv3M/rVHFYUFnNDVkd+fXUvmiTW7fD0Nk6v3vcS7RT4ElIVFY4X/7uBye+spEliHH+9+WyG\nnNm2zturaZxeve8l2tUa+GbWEfgH0BaoAKY65/5oZi2AbCAd2ACMcM7tDVyp0tAU7ith/Iw8Plq7\ni0t6tGbytX1plZJYr23WNE7/8YSLv1lHs3QkGvlyhl8G3O2c+9LMUoAlZvYucCvwnnNusplNACYA\n9wauVGlI5uRu5Vev51Na7vjt9/pw04CONfas91Vt4/TqfS/RzJd72hYChZ7H+82sAOgADAMu8qw2\nDXgfBb7UYt+hUn7z5jJm52
wls2MqU27IJL1lY79tX+P0It6d0lw3M0sH+gOfAW08vwyO/VJo7e/i\npGH579pdDP3jIubmFTL2O92ZeedAv4Y9VI7TJ1fpmKlxepFKPl+0NbMmwGvAaOdcsa9/fpvZKGAU\nQFpaWl1qlDB0Kt9YPewZQ3/+o6/o3LIxs34yiH4dUwNSl+5RK+KdOedqX8ksHpgLzHfOPelZtgq4\nyDlXaGbtgPedczWeRmVlZbnFixf7oWwJpaozYaDyLPqxa/qcFKwrthYzOnspq7cf4JaBZzDx8p4k\nJ6hnvcipMLMlzrms+m6n1iEdqzyVfx4oOBb2Hm8CIz2PRwKz61uMRIaaZsIcU17hePaDdQx7+iP2\nHirlxR+ew0PDeivsRULIlyGdwcDNQL6Z5XiW3QdMBqab2e3A18D1gSlRwk1tM2E27TnE3dNz+XzD\nHoae2ZbfXtOHFo0TglmiiFTDl1k6HwHeBuwv8W85Egm8zYRp1yyJmUs2M+nN5QA8cX0/rj2rg1+m\nW4pI/dW/I5VEnepmwiTFxdAqJZFxM3Lp1a4pb//yAq47+3SFvUgYUWsFOWVVZ8K0aJxAaXkFKwqL\nmXB5D350QWdi69GzXkQCQ4EvdTK8fweGnNmW384r4KVPN9K9TROm3NCfXu2bhro0EfFCgS91krup\niDHZOazfdZA7zu/EuCEZJMVrBo5IOFPgyykpK6/g6YXr+NOCNbROSeSVO85lUNeWoS5LRHygwBef\nfbXrIGOyc8jZVMSwzPY89N3eNGsUH+qyRMRHCnyplXOOf32+iYfnriA+1vjTTf35br/2oS5LRE6R\nAl9qtHP/ESa8lsd7K3cwuOtpPHF9P9o1U+dJkUikwBev3l2xnQmv5bH/SBm/uaoXtw5KJ0bTLUUi\nlgJfTnLgSBkPz1lB9uJN9GrXlFdvzKRbm5RQlyUi9aTAr4dTaREcKZZs3MOY7Fw27T3ETy/qwuhL\nu5MQpy9kizQECvw6qulm2ZEY+kfLKvjje6t55v11tE9NZvqPB3JOeotQlyUifqTAr6OaWgRHWuCv\n3bGf0dk5LNtSzPVnn85vru5FSpKmW4o0NAr8OqqtRXAkqKhw/OOTDTz29koaJcTy7A/OYmjvdqEu\nS0QCRIFfR5F+s+ztxYcZNyOXD9fs4qKMVvz+ur60TkkKdVkiEkC6GldHkXyz7LfyCrnsqUV8sWEP\njwzvzd9vPUdhLxIFdIZfR5F4s+ziw6U8MHs5ry/dQr+OqTw1oh+dWzUJdVkiEiQK/HoY3r9DWAf8\n8T5dv5u7p+eyrfgwoy/txs++3ZX42FP7A68hTkMViSa+3MT8BTPbYWbLjls2ycy2mFmO5+eKwJYp\ndXWkrJzfzivgpuc+JSEuhpl3DmT0pd3rFPYTZ+WzpagEx/+mob6xdEtgChcRv/Plv/oXgaHVLH/K\nOZfp+Znn37LEHwoKixn2l4+Zumg93x+Qxlu/OJ/+ac3rtK2apqGKSGTw5Sbmi8wsPfCliL9UVDj+\n9tF6npi/mqbJ8bxwaxYX92hTr202hGmoItGuPmP4d5nZLcBi4G7n3N7qVjKzUcAogLS0tHp8nPhi\n895DjJuRy6fr93BZrzY8dk0fTmuSWO/tRvo0VBGp+7TMZ4AuQCZQCPzB24rOuanOuSznXFarVq3q\n+HFSG+ccry/dzOVTPiR/8z5+f11f/nrz2X4Je4jsaagiUqlOZ/jOue3HHpvZc8Bcv1Ukp6zo0FF+\n9foy3sovJOuM5jx1QyYdWzTy62dE4jRUETlRnQLfzNo55wo9T78HLKtpfQmcD9fsZNyMXPYcPMo9\nQzP48YVdiA1Qz/pImoYqIierNfDN7F/ARUBLM9sMPABcZGaZgAM2AD8OYI1SjcOl5Ux+eyUv/ncD\n3Vo34fmR59C7Q7NQlyUiYcyXWTo3VbP4+QDUIj7K37yPMdNzWLvjALcN7sQ9QzNIqjK+LiJSlb5p\nG0HKyit49oN1TPnPGlo2SeSft5/L+d1ahrosEYkQCvwIsXH3QcZOz2XJxr1c1bcdjwzvTWqjhFCX\nJSIRRIEf5pxzTF+8iYfmrCAmxvjjjZkMy9SFUxE5dQr8MLbrwBEmzsrn3RXbGdj5NP4wop++6CQi\ndabAD1P/WbGdCbPyKD5cxv1X9uS2wZ2ICdB0SxGJDgr8MHPwSBmPvLWCf32+iZ7tmvLyHZlktE0J\ndVki0gAo8MPEG0u38OhbBew8cASAi3u05pkfnEViXO3TLdWnXkR8oVschoHXlmxm3Izcb8Ie4JN1\nu3k7f1ut71WfehHxlQI/xNbtPMCEWXmUVbgTlvvaa1596kXEVxrSCRHnHC99upGH566gtNxVu44v\nvebVp15EfKXAD4EdxYcZPzOPD1bvpKaJN75MwVSfehHxlYZ0guzt/EIum7KIz77aTbPkeCqqP7n3\nude8+tSLiK8U+EFSfLiUsdNz+MnLX5LWohFzf34BxSWlXtd/7Jo+Ps20Gd6/A49d04cOqckY0CE1\n2ef3ikh00ZBOEHy2fjdjp+dSuK+EX1zclZ9f0o342BivwzEdUpNPKbDVp15EfKEz/AA6UlbOY28X\ncONznxIXa8z8ySDGXpZBfGzlbtdwjIgEk87wA2TVtv2Mzs6hoLCYmwakcf+VPWmceOLu1m0DRSSY\nFPh+VlHheOHjr/j9O6tomhzH327J4tJebbyur+EYEQkWX25x+AJwFbDDOdfbs6wFkA2kU3mLwxHO\nub2BKzMybC0q4e7puXyyfjeX9mzD5Gv70LJJYqjLEhEBfBvDfxEYWmXZBOA951w34D3P86g2O2cL\nQ6YsIndzEb+7tg/P3XK2wl5Ewoov97RdZGbpVRYPo/LG5gDTgPeBe/1YV8TYd6iU+2cvY07uVs4+\nozlPjujHGac1DnVZIiInqesYfhvnXCGAc67QzFr7saawUlMnyo/W7GLcjFx2HTjCuMu6c+e3uhAX\nq4lPIhKeAn7R1sxGAaMA0tLS/L79QLYGPtaJ8lhzsmOdKI+WVVCwrZi/f7yBLq0a89wtg+lzejO/\nfKaISKDUNfC3m1k7z9l9O2CHtxWdc1OBqQBZWVleGgnUjbdABvwS+t46Ud73ej5lFY5bB6Vz79Ae\nJCfU3rNeRCTU6jr+8CYw0vN4JDDbP+WcmkC3BvbWcbKswvGP2wYw6btnKuxFJGLUGvhm9i/gEyDD\nzDab2e3AZOA7ZrYG+I7nedAFujWwt46T7ZomcWH3Vn75DBGRYPFlls5NXl66xM+1nLJAtwYed1l3\n7nkt74R+9UlxMdx7eQ+/bF9EJJgiekpJIHvR7D5whHeWb6O03JEQV7mbOqQmM/navvpmrIhEpIhu\nrRCoXjQLV+5g/Mw8iktKue+KHtxxfmdiarpTiYhIBIjowAf/9qI5dLSMR98q4OXPvqZH2xReun0A\nPds19cu2RURCLeID31+Wfr2XsdNz2bD7IKMu7MzY73QnKV4zcESk4Yj6wC8tr+AvC9b
yl4Vrads0\niVfuOI+BXU4LdVkiIn4X1YG/fucBxkzPJXdTEdf078CkYWfSNCk+1GWJiARERM/SqSvnHP/8dCND\np3xI3uYiAD77ag8LCrx+YVhEJOJF3Rn+jv2HuXdmHgtX7STGwHmm2Pu7LYOISLiJqjP8d5ZtY+iU\nD/nvut00S46nokpnH3+2ZRARCTdREfj7D5cyfkYud/5zCe1Tk3jrF+dTXFJa7br+assgIhJuGvyQ\nzhcb9jAmO4etRSXc9e2u/OKSbiTExQS8LYOISLhpsGf4R8sq+N07Kxnx10+IMWPGnQMZNyTjmzYJ\ngWzLICISjhrkGf7q7fsZ/WoOKwqLufGcjtx/VS+aJJ74PzVQbRlERMJVxAV+TXe4qqhw/P2/G/jd\nOytJSYzjuVuy+E6vNl635c+2DCIi4S6iAr+mO1yd27kF42bk8vHa3VzSozWTr+1Lq5TEUJYrIhJW\nIirwvd3h6sE5yymvcJRVOB67pg83ntMRM3W3FBE5XkQFvrcpk3sPldI/LZWnRmSS3rJxkKsSEYkM\nERX43qZSpiTFMePHA4mLbbCTjkRE6q1eCWlmG8ws38xyzGyxv4ryZvyQDJLiTiw5MTaGh4f1VtiL\niNTCH2f433bO7fLDdmrVrU0TUhslsK34MADtmiVx79AemmkjIuKDiBjSKa9wTF20niffXUXzRglM\nu20A3+reKtRliYhElPoGvgP+bWYO+KtzbmrVFcxsFDAKIC0t7ZQ/YNOeQ9w9PZfPN+zhij5teXR4\nH5o3Tqhn2SIi0ae+gT/YObfVzFoD75rZSufcouNX8PwSmAqQlZXlqttIdZxzzFyymQfnrMCAJ0f0\n43v9O2i6pYhIHdUr8J1zWz3/7jCz14EBwKKa31W7PQePct+sfN5Zvo0BnVrw5Ih+nN68UX03KyIS\n1eoc+GbWGIhxzu33PL4MeKi+BS1ctYN7ZuZRdOgoEy/vwR0XdCY2Rmf1IiL1VZ8z/DbA654hljjg\nFefcO3XdWMnRcn47r4CXPt1rJDRZAAAHNElEQVRIRpsUpv1wAL3aN61HeSIicrw6B75zbj3Qzx9F\n5G4qYkx2Dl/tPsiPLujE3ZdlkFSldbGIiNRPSKdllpVX8PTCdfxpwRrapCTy8h3nMqhLy1CWJCLS\nYIUs8L/adZAx2TnkbCpieGZ7BnRqwfgZeepNLyISIEEPfOccr3z+NY/MLSAhLoY/39Sf8grnte2x\nQl9ExD+CGvhlFY47pi3mvZU7OL9rS564vh9tmyUxePKCatsePz5/lQJfRMRPghr4a7bvp2TtLh64\nuhcjB6YT45lu6a3tsbflIiJy6oIa+HGxMcz9+fl0a5NywnJvbY/bpyYHqzQRkQYvqD2Fu7ZqclLY\nQ2Xb4+Qq0zCT42MZPyQjWKWJiDR4QT3D99YG59g4vbebk4uISP2FTXvk4f07KOBFRAJIt4kSEYkS\nCnwRkSihwBcRiRIKfBGRKKHAFxGJEgp8EZEoocAXEYkSCnwRkShRr8A3s6FmtsrM1prZBH8VJSIi\n/lfnwDezWOBp4HKgF3CTmfXyV2EiIuJf9TnDHwCsdc6td84dBV4FhvmnLBER8bf6BH4HYNNxzzd7\nlomISBiqT/O06npfupNWMhsFjPI8PWJmy+rxmcHSEtgV6iJ8oDr9JxJqBNXpb5FSp196xdcn8DcD\nHY97fjqwtepKzrmpwFQAM1vsnMuqx2cGher0r0ioMxJqBNXpb5FUpz+2U58hnS+AbmbWycwSgBuB\nN/1RlIiI+F+dz/Cdc2VmdhcwH4gFXnDOLfdbZSIi4lf1ugGKc24eMO8U3jK1Pp8XRKrTvyKhzkio\nEVSnv0VVnebcSddZRUSkAVJrBRGRKBGQwK+t5YKZJZpZtuf1z8wsPRB11FJjRzNbaGYFZrbczH5Z\nzToXmdk+M8vx/Pwm2HV66thgZvmeGk66Wm+V/uTZn3lmdlaQ68s4bh/lmFmxmY2usk5I9qWZvWBm\nO46fDmxmLczsXTNb4/m3uZf3jvSss8bMRoagzsfNbKXn/9PXzSzVy3trPD6CUOckM9ty3P+3V3h5\nb9BasXipM/u4GjeYWY6X9wZlf3rLoIAen845v/5QeQF3HdAZSABygV5V1vkp8Kzn8Y1Atr/r8KHO\ndsBZnscpwOpq6rwImBvs2qqpdQPQsobXrwDepvK7EecBn4Ww1lhgG3BGOOxL4ELgLGDZcct+D0zw\nPJ4A/K6a97UA1nv+be553DzIdV4GxHke/666On05PoJQ5yRgnA/HRY25EOg6q7z+B+A3odyf3jIo\nkMdnIM7wfWm5MAyY5nk8E7jEzKr7IlfAOOcKnXNfeh7vBwqI3G8KDwP+4Sp9CqSaWbsQ1XIJsM45\ntzFEn38C59wiYE+Vxccff9OA4dW8dQjwrnNuj3NuL/AuMDSYdTrn/u2cK/M8/ZTK77qElJf96Yug\ntmKpqU5P1owA/hWoz/dFDRkUsOMzEIHvS8uFb9bxHND7gNMCUItPPENK/YHPqnl5oJnlmtnbZnZm\nUAv7Hwf828yWWOU3l6sKpzYXN+L9P6Rw2JcAbZxzhVD5Hx3Qupp1wmmfAtxG5V9x1ant+AiGuzxD\nTy94GYIIp/15AbDdObfGy+tB359VMihgx2cgAt+Xlgs+tWUIBjNrArwGjHbOFVd5+Usqhyb6AX8G\n3gh2fR6DnXNnUdmZ9GdmdmGV18Nif1rlF/C+C8yo5uVw2Ze+Cot9CmBmvwLKgJe9rFLb8RFozwBd\ngEygkMrhkqrCZn8CN1Hz2X1Q92ctGeT1bdUsq3V/BiLwfWm58M06ZhYHNKNufybWi5nFU7mjX3bO\nzar6unOu2Dl3wPN4HhBvZi2DXCbOua2ef3cAr1P55/HxfGpzEQSXA18657ZXfSFc9qXH9mNDXp5/\nd1SzTljsU8/FuKuA/3OewduqfDg+Aso5t905V+6cqwCe8/L54bI/44BrgGxv6wRzf3rJoIAdn4EI\nfF9aLrwJHLuqfB2wwNvBHCiecbzngQLn3JNe1ml77NqCmQ2gcn/tDl6VYGaNzSzl2GMqL+RVbUD3\nJnCLVToP2HfsT8Ig83rmFA778jjHH38jgdnVrDMfuMzMmnuGKC7zLAsaMxsK3At81zl3yMs6vhwf\nAVXletH3vHx+uLRiuRRY6ZzbXN2LwdyfNWRQ4I7PAF19voLKK87rgF95lj1E5YELkETln/1rgc+B\nzoG8Gu6lxvOp/BMoD8jx/FwB3Anc6VnnLmA5lTMKPgUGhaDOzp7Pz/XUcmx/Hl+nUXkzmnVAPpAV\ngjobURngzY5bFvJ9SeUvoEKglMqzotupvF70HrDG828Lz7pZwN+Oe+9tnmN0LfDDENS5lspx2mPH\n57GZbe2BeTUdH0Gu8yXPcZdHZVi1q1qn5/lJuRDMOj3LXzx2TB63bkj2Zw0ZFLDjU9+0FRGJEvqm\nrYhIlFDgi4hECQW+iEiUUOCLiEQJBb6ISJ
RQ4IuIRAkFvohIlFDgi4hEif8HC5wNJAVBQWYAAAAA\nSUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "2.0188677310943604 2.8898627758026123\n" ] } ], "source": [ "# 随机初始化参数\n", "w = V(t.rand(1,1), requires_grad=True)\n", "b = V(t.zeros(1,1), requires_grad=True)\n", "\n", "lr =0.001 # 学习率\n", "\n", "for ii in range(8000):\n", " x, y = get_fake_data()\n", " x, y = V(x), V(y)\n", " \n", " # forward:计算loss\n", " y_pred = x.mm(w) + b.expand_as(y)\n", " loss = 0.5 * (y_pred - y) ** 2\n", " loss = loss.sum()\n", " \n", " # backward:手动计算梯度\n", " loss.backward()\n", " \n", " # 更新参数\n", " w.data.sub_(lr * w.grad.data)\n", " b.data.sub_(lr * b.grad.data)\n", " \n", " # 梯度清零\n", " w.grad.data.zero_()\n", " b.grad.data.zero_()\n", " \n", " if ii%1000 ==0:\n", " # 画图\n", " display.clear_output(wait=True)\n", " x = t.arange(0, 20).view(-1, 1)\n", " y = x.mm(w.data) + b.data.expand_as(x)\n", " plt.plot(x.numpy(), y.numpy()) # predicted\n", " \n", " x2, y2 = get_fake_data(batch_size=20) \n", " plt.scatter(x2.numpy(), y2.numpy()) # true data\n", " \n", " plt.xlim(0,20)\n", " plt.ylim(0,41) \n", " plt.show()\n", " plt.pause(0.5)\n", " \n", "print(w.data.squeeze()[0], b.data.squeeze()[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "用autograd实现的线性回归最大的不同点就在于autograd不需要计算反向传播,可以自动计算微分。这点不单是在深度学习,在许多机器学习的问题中都很有用。另外需要注意的是在每次反向传播之前要记得先把梯度清零。\n", "\n", "本章主要介绍了PyTorch中两个基础底层的数据结构:Tensor和autograd中的Variable。Tensor是一个类似Numpy数组的高效多维数值运算数据结构,有着和Numpy相类似的接口,并提供简单易用的GPU加速。Variable是autograd封装了Tensor并提供自动求导技术的,具有和Tensor几乎一样的接口。`autograd`是PyTorch的自动微分引擎,采用动态计算图技术,能够快速高效的计算导数。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }