{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 第三章 PyTorch基础:Tensor和Autograd\n", "\n", "## 3.1 Tensor\n", "\n", "Tensor,又名张量,读者可能对这个名词似曾相识,因它不仅在PyTorch中出现过,它也是Theano、TensorFlow、\n", "Torch和MxNet中重要的数据结构。关于张量的本质不乏深度的剖析,但从工程角度来讲,可简单地认为它就是一个数组,且支持高效的科学计算。它可以是一个数(标量)、一维数组(向量)、二维数组(矩阵)和更高维的数组(高阶数据)。Tensor和Numpy的ndarrays类似,但PyTorch的tensor支持GPU加速。\n", "\n", "本节将系统讲解tensor的使用,力求面面俱到,但不会涉及每个函数。对于更多函数及其用法,读者可通过在IPython/Notebook中使用函数名加`?`查看帮助文档,或查阅PyTorch官方文档[^1]。\n", "\n", "[^1]: http://docs.pytorch.org" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'0.4.1'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's begin\n", "from __future__ import print_function\n", "import torch as t\n", "t.__version__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1.1 基础操作\n", "\n", "学习过Numpy的读者会对本节内容感到非常熟悉,因tensor的接口有意设计成与Numpy类似,以方便用户使用。但不熟悉Numpy也没关系,本节内容并不要求先掌握Numpy。\n", "\n", "从接口的角度来讲,对tensor的操作可分为两类:\n", "\n", "1. `torch.function`,如`torch.save`等。\n", "2. 另一类是`tensor.function`,如`tensor.view`等。\n", "\n", "为方便使用,对tensor的大部分操作同时支持这两类接口,在本书中不做具体区分,如`torch.sum (torch.sum(a, b))`与`tensor.sum (a.sum(b))`功能等价。\n", "\n", "而从存储的角度来讲,对tensor的操作又可分为两类:\n", "\n", "1. 不会修改自身的数据,如 `a.add(b)`, 加法的结果会返回一个新的tensor。\n", "2. 会修改自身的数据,如 `a.add_(b)`, 加法的结果仍存储在a中,a被修改了。\n", "\n", "函数名以`_`结尾的都是inplace方式, 即会修改调用者自己的数据,在实际应用中需加以区分。\n", "\n", "#### 创建Tensor\n", "\n", "在PyTorch中新建tensor的方法有很多,具体如表3-1所示。\n", "\n", "表3-1: 常见新建tensor的方法\n", "\n", "|函数|功能|\n", "|:---:|:---:|\n", "|Tensor(\\*sizes)|基础构造函数|\n", "|ones(\\*sizes)|全1Tensor|\n", "|zeros(\\*sizes)|全0Tensor|\n", "|eye(\\*sizes)|对角线为1,其他为0|\n", "|arange(s,e,step|从s到e,步长为step|\n", "|linspace(s,e,steps)|从s到e,均匀切分成steps份|\n", "|rand/randn(\\*sizes)|均匀/标准分布|\n", "|normal(mean,std)/uniform(from,to)|正态分布/均匀分布|\n", "|randperm(m)|随机排列|\n", "\n", "其中使用`Tensor`函数新建tensor是最复杂多变的方式,它既可以接收一个list,并根据list的数据新建tensor,也能根据指定的形状新建tensor,还能传入其他的tensor,下面举几个例子。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000]])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 指定tensor的形状\n", "a = t.Tensor(2, 3)\n", "a # 数值取决于内存空间的状态" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 2., 3.],\n", " [4., 5., 6.]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 用list的数据创建tensor\n", "b = t.Tensor([[1,2,3],[4,5,6]])\n", "b" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.tolist() # 把tensor转为list" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`tensor.size()`返回`torch.Size`对象,它是tuple的子类,但其使用方式与tuple略有区别" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b_size = b.size()\n", "b_size" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.numel() # 
#### Creating tensors

There are many ways to create a tensor in PyTorch; the common ones are listed in Table 3-1.

Table 3-1: Common tensor-creation functions

|Function|Purpose|
|:---:|:---:|
|Tensor(\*sizes)|basic constructor|
|ones(\*sizes)|all-ones tensor|
|zeros(\*sizes)|all-zeros tensor|
|eye(\*sizes)|ones on the diagonal, zeros elsewhere|
|arange(s,e,step)|from s to e with step `step`|
|linspace(s,e,steps)|from s to e, evenly split into `steps` points|
|rand/randn(\*sizes)|uniform/standard-normal random values|
|normal(mean,std)/uniform(from,to)|normal/uniform distribution|
|randperm(m)|random permutation|

Of these, the `Tensor` constructor is the most flexible and varied: it can take a list and build a tensor from its data, build a tensor of a given shape, or take another tensor. A few examples follow.

```python
# specify the shape
a = t.Tensor(2, 3)
a  # the values depend on the state of the memory it happens to get
```

```python
# build a tensor from list data
b = t.Tensor([[1, 2, 3], [4, 5, 6]])
b
# -> tensor([[1., 2., 3.],
#            [4., 5., 6.]])
```

```python
b.tolist()  # convert the tensor to a list
# -> [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```

`tensor.size()` returns a `torch.Size` object. It is a subclass of `tuple`, though its usage differs slightly from a plain tuple.

```python
b_size = b.size()
b_size  # -> torch.Size([2, 3])
```

```python
b.numel()  # total number of elements in b: 2*3 = 6; equivalent to b.nelement()
```

```python
# create a tensor with the same shape as b
c = t.Tensor(b_size)
# create a tensor whose elements are 2 and 3
d = t.Tensor((2, 3))
c, d  # c is uninitialized, d -> tensor([2., 3.])
```

Besides `tensor.size()`, the shape can be read directly from `tensor.shape`, which is equivalent to `tensor.size()`.

```python
c.shape  # -> torch.Size([2, 3])
```

```python
c.shape??  # view the help for the shape attribute
```

Note that when a tensor is created with `t.Tensor(*sizes)`, the system does not allocate the space immediately; it only checks that enough memory remains, and the space is allocated when the tensor is first used. All other creation methods allocate space as soon as the tensor is created. Examples of the other common creation functions:

```python
t.ones(2, 3)
```

```python
t.zeros(2, 3)
```

```python
t.arange(1, 6, 2)  # -> tensor([1, 3, 5])
```

```python
t.linspace(1, 10, 3)  # -> tensor([ 1.0000,  5.5000, 10.0000])
```

```python
t.randn(2, 3)  # values are random
```

```python
t.randperm(5)  # a random permutation of length 5
```

```python
t.eye(2, 3)  # ones on the diagonal; rows and columns need not match
```
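Since PyTorch 0.4 there is also the lowercase `t.tensor`, which always treats its argument as data; a brief sketch of why it is often less ambiguous than the overloaded `t.Tensor` constructor:

```python
import torch as t

print(t.Tensor(2, 3).shape)  # torch.Size([2, 3]) -- 2 and 3 are read as a shape
print(t.Tensor([2, 3]))      # tensor([2., 3.])   -- a list is read as data
print(t.tensor([2, 3]))      # tensor([2, 3])     -- always data; dtype inferred (int64 here)
```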
"execute_result" } ], "source": [ "a = t.arange(0, 6)\n", "a.view(2, 3)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1, 2],\n", " [3, 4, 5]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = a.view(-1, 3) # 当某一维为-1的时候,会自动计算它的大小\n", "b" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[0, 1, 2]],\n", "\n", " [[3, 4, 5]]])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.unsqueeze(1) # 注意形状,在第1维(下标从0开始)上增加“1”" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[0, 1, 2]],\n", "\n", " [[3, 4, 5]]])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.unsqueeze(-2) # -2表示倒数第二个维度" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[[[0, 1, 2],\n", " [3, 4, 5]]]])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = b.view(1, 1, 1, 2, 3)\n", "c.squeeze(0) # 压缩第0维的“1”" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1, 2],\n", " [3, 4, 5]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.squeeze() # 把所有维度为“1”的压缩" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0, 100, 2],\n", " [ 3, 4, 5]])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[1] = 100\n", "b # a修改,b作为view之后的,也会跟着修改" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`resize`是另一种可用来调整`size`的方法,但与`view`不同,它可以修改tensor的大小。如果新大小超过了原大小,会自动分配新的内存空间,而如果新大小小于原大小,则之前的数据依旧会被保存,看一个例子。" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 100 2\n", "[torch.FloatTensor of size 1x3]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.resize_(1, 3)\n", "b" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 0.0000e+00 1.0000e+02 2.0000e+00\n", " 3.0000e+00 4.0000e+00 5.0000e+00\n", " 4.1417e+36 4.5731e-41 6.7262e-44\n", "[torch.FloatTensor of size 3x3]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.resize_(3, 3) # 旧的数据依旧保存着,多出的大小会分配新空间\n", "b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 索引操作\n", "\n", "Tensor支持与numpy.ndarray类似的索引操作,语法上也类似,下面通过一些例子,讲解常用的索引操作。如无特殊说明,索引出来的结果与原tensor共享内存,也即修改一个,另一个会跟着修改。" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0.2355 0.8276 0.6279 -2.3826\n", " 0.3533 1.3359 0.1627 1.7314\n", " 0.8121 0.3059 2.4352 1.4577\n", "[torch.FloatTensor of size 3x4]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = t.randn(3, 4)\n", "a" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0.2355\n", " 0.8276\n", " 0.6279\n", "-2.3826\n", "[torch.FloatTensor of size 4]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"a[0] # 第0行(下标从0开始)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0.2355\n", " 0.3533\n", " 0.8121\n", "[torch.FloatTensor of size 3]" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[:, 0] # 第0列" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6279084086418152" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[0][2] # 第0行第2个元素,等价于a[0, 2]" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-2.3825833797454834" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[0, -1] # 第0行最后一个元素" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 0.2355 0.8276 0.6279 -2.3826\n", " 0.3533 1.3359 0.1627 1.7314\n", "[torch.FloatTensor of size 2x4]" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[:2] # 前两行" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0.2355 0.8276\n", " 0.3533 1.3359\n", "[torch.FloatTensor of size 2x2]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[:2, 0:2] # 前两行,第0,1列" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 0.2355 0.8276\n", "[torch.FloatTensor of size 1x2]\n", "\n", "\n", " 0.2355\n", " 0.8276\n", "[torch.FloatTensor of size 2]\n", "\n" ] } ], "source": [ "print(a[0:1, :2]) # 第0行,前两列 \n", "print(a[0, :2]) # 注意两者的区别:形状不同" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 0 0 0\n", " 0 1 0 1\n", " 0 0 1 1\n", "[torch.ByteTensor of size 3x4]" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a > 1 # 返回一个ByteTensor" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1.3359\n", " 1.7314\n", " 2.4352\n", " 1.4577\n", "[torch.FloatTensor of size 4]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[a>1] # 等价于a.masked_select(a>1)\n", "# 选择结果与原tensor不共享内存空间" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 0.2355 0.8276 0.6279 -2.3826\n", " 0.3533 1.3359 0.1627 1.7314\n", "[torch.FloatTensor of size 2x4]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[t.LongTensor([0,1])] # 第0行和第1行" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其它常用的选择函数如表3-2所示。\n", "\n", "表3-2常用的选择函数\n", "\n", "函数|功能|\n", ":---:|:---:|\n", "index_select(input, dim, index)|在指定维度dim上选取,比如选取某些行、某些列\n", "masked_select(input, mask)|例子如上,a[a>0],使用ByteTensor进行选取\n", "non_zero(input)|非0元素的下标\n", "gather(input, dim, index)|根据index,在dim维度上选取数据,输出的size与index一样\n", "\n", "\n", "`gather`是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:\n", "\n", "```python\n", "out[i][j] = input[index[i][j]][j] # dim=0\n", "out[i][j] = input[i][index[i][j]] # dim=1\n", "```\n", "三维tensor的`gather`操作同理,下面举几个例子。" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 1 2 3\n", " 
4 5 6 7\n", " 8 9 10 11\n", " 12 13 14 15\n", "[torch.FloatTensor of size 4x4]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = t.arange(0, 16).view(4, 4)\n", "a" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 5 10 15\n", "[torch.FloatTensor of size 1x4]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 选取对角线的元素\n", "index = t.LongTensor([[0,1,2,3]])\n", "a.gather(0, index)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 3\n", " 6\n", " 9\n", " 12\n", "[torch.FloatTensor of size 4x1]" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 选取反对角线上的元素\n", "index = t.LongTensor([[3,2,1,0]]).t()\n", "a.gather(1, index)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 12 9 6 3\n", "[torch.FloatTensor of size 1x4]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 选取反对角线上的元素,注意与上面的不同\n", "index = t.LongTensor([[3,2,1,0]])\n", "a.gather(0, index)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 3\n", " 5 6\n", " 10 9\n", " 15 12\n", "[torch.FloatTensor of size 4x2]" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 选取两个对角线上的元素\n", "index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()\n", "b = a.gather(1, index)\n", "b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "与`gather`相对应的逆操作是`scatter_`,`gather`把数据从input中按index取出,而`scatter_`是把取出的数据再放回去。注意`scatter_`函数是inplace操作。\n", "\n", "```python\n", "out = input.gather(dim, index)\n", "-->近似逆操作\n", "out = Tensor()\n", "out.scatter_(dim, index)\n", "```" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 0 0 0 3\n", " 0 5 6 0\n", " 0 9 10 0\n", " 12 0 0 15\n", "[torch.FloatTensor of size 4x4]" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 把两个对角线元素放回去到指定位置\n", "c = t.zeros(4,4)\n", "c.scatter_(1, index, b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 高级索引\n", "PyTorch在0.2版本中完善了索引操作,目前已经支持绝大多数numpy的高级索引[^10]。高级索引可以看成是普通索引操作的扩展,但是高级索引操作的结果一般不和原始的Tensor贡献内出。 \n", "[^10]: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", "(0 ,.,.) = \n", " 0 1 2\n", " 3 4 5\n", " 6 7 8\n", "\n", "(1 ,.,.) = \n", " 9 10 11\n", " 12 13 14\n", " 15 16 17\n", "\n", "(2 ,.,.) 
= \n", " 18 19 20\n", " 21 22 23\n", " 24 25 26\n", "[torch.FloatTensor of size 3x3x3]" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = t.arange(0,27).view(3,3,3)\n", "x" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 14\n", " 24\n", "[torch.FloatTensor of size 2]" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[[1, 2], [1, 2], [2, 0]] # x[1,1,2]和x[2,2,0]" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 19\n", " 10\n", " 1\n", "[torch.FloatTensor of size 3]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[[2, 1, 0], [0], [1]] # x[2,0,1],x[1,0,1],x[0,0,1]" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", "(0 ,.,.) = \n", " 0 1 2\n", " 3 4 5\n", " 6 7 8\n", "\n", "(1 ,.,.) = \n", " 18 19 20\n", " 21 22 23\n", " 24 25 26\n", "[torch.FloatTensor of size 2x3x3]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[[0, 2], ...] # x[0] 和 x[2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Tensor类型\n", "\n", "Tensor有不同的数据类型,如表3-3所示,每种类型分别对应有CPU和GPU版本(HalfTensor除外)。默认的tensor是FloatTensor,可通过`t.set_default_tensor_type` 来修改默认tensor类型(如果默认类型为GPU tensor,则所有操作都将在GPU上进行)。Tensor的类型对分析内存占用很有帮助。例如对于一个size为(1000, 1000, 1000)的FloatTensor,它有`1000*1000*1000=10^9`个元素,每个元素占32bit/8 = 4Byte内存,所以共占大约4GB内存/显存。HalfTensor是专门为GPU版本设计的,同样的元素个数,显存占用只有FloatTensor的一半,所以可以极大缓解GPU显存不足的问题,但由于HalfTensor所能表示的数值大小和精度有限[^2],所以可能出现溢出等问题。\n", "\n", "[^2]: https://stackoverflow.com/questions/872544/what-range-of-numbers-can-be-represented-in-a-16-32-and-64-bit-ieee-754-syste\n", "\n", "表3-3: tensor数据类型\n", "\n", "数据类型|\tCPU tensor\t|GPU tensor|\n", ":---:|:---:|:--:|\n", "32-bit 浮点|\ttorch.FloatTensor\t|torch.cuda.FloatTensor\n", "64-bit 浮点|\ttorch.DoubleTensor|\ttorch.cuda.DoubleTensor\n", "16-bit 半精度浮点|\tN/A\t|torch.cuda.HalfTensor\n", "8-bit 无符号整形(0~255)|\ttorch.ByteTensor|\ttorch.cuda.ByteTensor\n", "8-bit 有符号整形(-128~127)|\ttorch.CharTensor\t|torch.cuda.CharTensor\n", "16-bit 有符号整形 |\ttorch.ShortTensor|\ttorch.cuda.ShortTensor\n", "32-bit 有符号整形 \t|torch.IntTensor\t|torch.cuda.IntTensor\n", "64-bit 有符号整形 \t|torch.LongTensor\t|torch.cuda.LongTensor\n", "\n", "各数据类型之间可以互相转换,`type(new_type)`是通用的做法,同时还有`float`、`long`、`half`等快捷方法。CPU tensor与GPU tensor之间的互相转换通过`tensor.cuda`和`tensor.cpu`方法实现。Tensor还有一个`new`方法,用法与`t.Tensor`一样,会调用该tensor对应类型的构造函数,生成与当前tensor类型一致的tensor。" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "# 设置默认tensor,注意参数是字符串\n", "t.set_default_tensor_type('torch.IntTensor')" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", "-1.7683e+09 2.1918e+04 1.0000e+00\n", " 0.0000e+00 1.0000e+00 0.0000e+00\n", "[torch.IntTensor of size 2x3]" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = t.Tensor(2,3)\n", "a # 现在a是IntTensor" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", "-1.7683e+09 2.1918e+04 1.0000e+00\n", " 0.0000e+00 1.0000e+00 0.0000e+00\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 
#### Tensor types

Tensors come in several data types, listed in Table 3-3, each with a CPU and a GPU variant (except HalfTensor). The default tensor type is FloatTensor; it can be changed with `t.set_default_tensor_type` (if the default is a GPU type, all operations run on the GPU). Knowing the type helps when estimating memory usage: a FloatTensor of size (1000, 1000, 1000) has `1000*1000*1000 = 10^9` elements, each occupying 32 bit / 8 = 4 bytes, about 4 GB of memory/video memory in total. HalfTensor is designed specifically for GPUs: with the same number of elements it uses only half the video memory of FloatTensor, which can greatly relieve GPU memory pressure; but because the range and precision of half-precision values are limited[^2], problems such as overflow can occur.

[^2]: https://stackoverflow.com/questions/872544/what-range-of-numbers-can-be-represented-in-a-16-32-and-64-bit-ieee-754-syste

Table 3-3: Tensor data types

|Data type|CPU tensor|GPU tensor|
|:---:|:---:|:--:|
|32-bit float|torch.FloatTensor|torch.cuda.FloatTensor|
|64-bit float|torch.DoubleTensor|torch.cuda.DoubleTensor|
|16-bit half-precision float|N/A|torch.cuda.HalfTensor|
|8-bit unsigned integer (0~255)|torch.ByteTensor|torch.cuda.ByteTensor|
|8-bit signed integer (-128~127)|torch.CharTensor|torch.cuda.CharTensor|
|16-bit signed integer|torch.ShortTensor|torch.cuda.ShortTensor|
|32-bit signed integer|torch.IntTensor|torch.cuda.IntTensor|
|64-bit signed integer|torch.LongTensor|torch.cuda.LongTensor|

The types can be converted into one another: `type(new_type)` is the general mechanism, and there are shortcuts such as `float`, `long` and `half`. Conversion between CPU and GPU tensors is done with `tensor.cuda` and `tensor.cpu`. A tensor also has a `new` method with the same signature as `t.Tensor`; it calls the constructor of the tensor's own type, producing a new tensor of the same type as the caller.

```python
# set the default tensor type; note the argument is a string
t.set_default_tensor_type('torch.IntTensor')
```

```python
a = t.Tensor(2, 3)
a  # a is now an IntTensor (uninitialized values)
```

```python
# convert a to a FloatTensor, equivalent to b = a.type(t.FloatTensor)
b = a.float()
b
```

```python
c = a.type_as(b)
c
```

```python
d = a.new(2, 3)  # equivalent to torch.IntTensor(2, 3)
d
```

```python
# view the source of the new method
a.new??
# Constructs a new tensor of the same data type as the `self` tensor;
# for CUDA tensors, the new tensor is created on the same device.
```
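The per-element arithmetic above can also be checked directly on a tensor; a small sketch (the shape is arbitrary, and the type is given explicitly so the result does not depend on the current default):

```python
import torch as t

a = t.FloatTensor(1000, 1000)
print(a.element_size())  # 4 -- bytes per float32 element
print(a.nelement())      # 1000000
print(a.element_size() * a.nelement() / 1024 ** 2)  # ~3.8 (MB)
```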
```python
# restore the earlier default
t.set_default_tensor_type('torch.FloatTensor')
```

#### Element-wise operations

These operations act on each element of a tensor independently (point-wise, a.k.a. element-wise); the output has the same shape as the input. The common ones are listed in Table 3-4.

Table 3-4: Common element-wise operations

|Function|Purpose|
|:--:|:--:|
|abs/sqrt/div/exp/fmod/log/pow..|absolute value/square root/division/exponential/modulo/power..|
|cos/sin/asin/atan2/cosh..|trigonometric functions|
|ceil/round/floor/trunc|ceiling/rounding/floor/truncate to the integer part|
|clamp(input, min, max)|clip values beyond min and max|
|sigmoid/tanh..|activation functions|

For many operations, e.g. `div`, `mul`, `pow` and `fmod`, PyTorch overloads the Python operators, so they can be used directly: `a ** 2` is equivalent to `torch.pow(a, 2)`, and `a * 2` to `torch.mul(a, 2)`.

The output of `clamp(x, min, max)` satisfies:
$$
y_i =
\begin{cases}
min, & \text{if } x_i \lt min \\
x_i, & \text{if } min \le x_i \le max \\
max, & \text{if } x_i \gt max\\
\end{cases}
$$
`clamp` is handy wherever values need to be compared against bounds, e.g. taking the element-wise maximum of a tensor and a constant.

```python
a = t.arange(0, 6).view(2, 3)
t.cos(a)
```

```python
a % 3  # equivalent to t.fmod(a, 3)
```

```python
a ** 2  # equivalent to t.pow(a, 2)
```

```python
# element-wise maximum of a and 3 (values below 3 are clipped to 3)
print(a)
t.clamp(a, min=3)
```

#### Reductions

These operations produce an output smaller than the input, and they can work along a chosen dimension. `sum`, for instance, can add up the whole tensor or each row or column separately. The common reductions are listed in Table 3-5.

Table 3-5: Common reductions

|Function|Purpose|
|:---:|:---:|
|mean/sum/median/mode|mean/sum/median/mode|
|norm/dist|norm/distance|
|std/var|standard deviation/variance|
|cumsum/cumprod|cumulative sum/cumulative product|

Most of these functions take a **`dim`** parameter that selects the dimension along which to reduce. Explanations of `dim` (NumPy's `axis`) vary widely; here is a simple rule of thumb.

Suppose the input has shape (m, n, k):

- with dim=0, the output has shape (1, n, k) or (n, k)
- with dim=1, the output has shape (m, 1, k) or (m, k)
- with dim=2, the output has shape (m, n, 1) or (m, n)

Whether the size keeps the "1" depends on the `keepdim` parameter: `keepdim=True` keeps the dimension of size `1`. Note this is only a rule of thumb; not every function changes shapes this way, `cumsum` for example.
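These shape rules can be checked quickly on a three-dimensional tensor (a minimal sketch; the shape is arbitrary):

```python
import torch as t

a = t.ones(2, 3, 4)                       # (m, n, k) = (2, 3, 4)
print(a.sum(dim=1).size())                # torch.Size([2, 4])
print(a.sum(dim=1, keepdim=True).size())  # torch.Size([2, 1, 4])
print(a.cumsum(dim=1).size())             # torch.Size([2, 3, 4]) -- the exception
```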
"text/plain": [ "\n", " 2\n", " 2\n", " 2\n", "[torch.FloatTensor of size 3]" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# keepdim=False,不保留维度\"1\",注意形状\n", "b.sum(dim=0, keepdim=False)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 3\n", " 3\n", "[torch.FloatTensor of size 2]" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.sum(dim=1)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 0 1 2\n", " 3 4 5\n", "[torch.FloatTensor of size 2x3]\n", "\n" ] }, { "data": { "text/plain": [ "\n", " 0 1 3\n", " 3 7 12\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = t.arange(0, 6).view(2, 3)\n", "print(a)\n", "a.cumsum(dim=1) # 沿着行累加" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 比较\n", "比较函数中有一些是逐元素比较,操作类似于逐元素操作,还有一些则类似于归并操作。常用比较函数如表3-6所示。\n", "\n", "表3-6: 常用比较函数\n", "\n", "|函数|功能|\n", "|:--:|:--:|\n", "|gt/lt/ge/le/eq/ne|大于/小于/大于等于/小于等于/等于/不等|\n", "|topk|最大的k个数|\n", "|sort|排序|\n", "|max/min|比较两个tensor最大最小值|\n", "\n", "表中第一行的比较操作已经实现了运算符重载,因此可以使用`a>=b`、`a>b`、`a!=b`、`a==b`,其返回结果是一个`ByteTensor`,可用来选取元素。max/min这两个操作比较特殊,以max来说,它有以下三种使用情况:\n", "- t.max(tensor):返回tensor中最大的一个数\n", "- t.max(tensor,dim):指定维上最大的数,返回tensor和下标\n", "- t.max(tensor1, tensor2): 比较两个tensor相比较大的元素\n", "\n", "至于比较一个tensor和一个数,可以使用clamp函数。下面举例说明。" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 3 6\n", " 9 12 15\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = t.linspace(0, 15, 6).view(2, 3)\n", "a" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 15 12 9\n", " 6 3 0\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = t.linspace(15, 0, 6).view(2, 3)\n", "b" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 0 0\n", " 1 1 1\n", "[torch.ByteTensor of size 2x3]" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a>b" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 9\n", " 12\n", " 15\n", "[torch.FloatTensor of size 3]" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[a>b] # a中大于b的元素" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "15.0" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.max(a)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(\n", " 15\n", " 6\n", " [torch.FloatTensor of size 2], \n", " 0\n", " 0\n", " [torch.LongTensor of size 2])" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.max(b, dim=1) \n", "# 第一个返回值的15和6分别表示第0行和第1行最大的元素\n", "# 第二个返回值的0和0表示上述最大的数是该行第0个元素" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 15 12 9\n", 
" 9 12 15\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.max(a,b)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 10 10 10\n", " 10 12 15\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 比较a和10较大的元素\n", "t.clamp(a, min=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 线性代数\n", "\n", "PyTorch的线性函数主要封装了Blas和Lapack,其用法和接口都与之类似。常用的线性代数函数如表3-7所示。\n", "\n", "表3-7: 常用的线性代数函数\n", "\n", "|函数|功能|\n", "|:---:|:---:|\n", "|trace|对角线元素之和(矩阵的迹)|\n", "|diag|对角线元素|\n", "|triu/tril|矩阵的上三角/下三角,可指定偏移量|\n", "|mm/bmm|矩阵乘法,batch的矩阵乘法|\n", "|addmm/addbmm/addmv/addr/badbmm..|矩阵运算\n", "|t|转置|\n", "|dot/cross|内积/外积\n", "|inverse|求逆矩阵\n", "|svd|奇异值分解\n", "\n", "具体使用说明请参见官方文档[^3],需要注意的是,矩阵的转置会导致存储空间不连续,需调用它的`.contiguous`方法将其转为连续。\n", "[^3]: http://pytorch.org/docs/torch.html#blas-and-lapack-operations" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = a.t()\n", "b.is_contiguous()" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 9\n", " 3 12\n", " 6 15\n", "[torch.FloatTensor of size 3x2]" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.contiguous()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1.2 Tensor和Numpy\n", "\n", "Tensor和Numpy数组之间具有很高的相似性,彼此之间的互操作也非常简单高效。需要注意的是,Numpy和Tensor共享内存。由于Numpy历史悠久,支持丰富的操作,所以当遇到Tensor不支持的操作时,可先转成Numpy数组,处理后再转回tensor,其转换开销很小。" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 1., 1.],\n", " [1., 1., 1.]], dtype=float32)" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "a = np.ones([2, 3],dtype=np.float32)\n", "a" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1 1 1\n", " 1 1 1\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = t.from_numpy(a)\n", "b" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1 1 1\n", " 1 1 1\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = t.Tensor(a) # 也可以直接将numpy对象传入Tensor\n", "b" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 1 100 1\n", " 1 1 1\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[0, 1]=100\n", "b" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1., 100., 1.],\n", " [ 1., 1., 1.]], dtype=float32)" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = b.numpy() # a, b, c三个对象共享内存\n", "c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**注意**: 当numpy的数据类型和Tensor的类型不一样的时候,数据会被复制,不会共享内存。" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "array([[1., 1., 1.],\n", " [1., 1., 1.]])" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = np.ones([2, 3])\n", "a # 注意和上面的a的区别(dtype不是float32)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1 1 1\n", " 1 1 1\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = t.Tensor(a) # FloatTensor(double64或者float64)\n", "b" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1 1 1\n", " 1 1 1\n", "[torch.DoubleTensor of size 2x3]" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = t.from_numpy(a) # 注意c的类型(DoubleTensor)\n", "c" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1 1 1\n", " 1 1 1\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a[0, 1] = 100\n", "b # b与a不通向内存,所以即使a改变了,b也不变" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 1 100 1\n", " 1 1 1\n", "[torch.DoubleTensor of size 2x3]" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c # c与a共享内存" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "广播法则(broadcast)是科学运算中经常使用的一个技巧,它在快速执行向量化的同时不会占用额外的内存/显存。\n", "Numpy的广播法则定义如下:\n", "\n", "- 让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分通过在前面加1补齐\n", "- 两个数组要么在某一个维度的长度一致,要么其中一个为1,否则不能计算 \n", "- 当输入数组的某个维度的长度为1时,计算时沿此维度复制扩充成一样的形状\n", "\n", "PyTorch当前已经支持了自动广播法则,但是笔者还是建议读者通过以下两个函数的组合手动实现广播法则,这样更直观,更不易出错:\n", "\n", "- `unsqueeze`或者`view`:为数据某一维的形状补1,实现法则1\n", "- `expand`或者`expand_as`,重复数组,实现法则3;该操作不会复制数组,所以不会占用额外的空间。\n", "\n", "注意,repeat实现与expand相类似的功能,但是repeat会把相同数据复制多份,因此会占用额外的空间。" ] }, { "cell_type": "code", "execution_count": 82, "metadata": { "scrolled": true }, "outputs": [], "source": [ "a = t.ones(3, 2)\n", "b = t.zeros(2, 3,1)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\n", "(0 ,.,.) = \n", " 1 1\n", " 1 1\n", " 1 1\n", "\n", "(1 ,.,.) = \n", " 1 1\n", " 1 1\n", " 1 1\n", "[torch.FloatTensor of size 2x3x2]" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 自动广播法则\n", "# 第一步:a是2维,b是3维,所以先在较小的a前面补1 ,\n", "# 即:a.unsqueeze(0),a的形状变成(1,3,2),b的形状是(2,3,1),\n", "# 第二步: a和b在第一维和第三维形状不一样,其中一个为1 ,\n", "# 可以利用广播法则扩展,两个形状都变成了(2,3,2)\n", "a+b" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", "(0 ,.,.) = \n", " 1 1\n", " 1 1\n", " 1 1\n", "\n", "(1 ,.,.) 
= \n", " 1 1\n", " 1 1\n", " 1 1\n", "[torch.FloatTensor of size 2x3x2]" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 手动广播法则\n", "# 或者 a.view(1,3,2).expand(2,3,2)+b.expand(2,3,2)\n", "a.unsqueeze(0).expand(2, 3, 2) + b.expand(2,3,2)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "# expand不会占用额外空间,只会在需要的时候才扩充,可极大节省内存\n", "e = a.unsqueeze(0).expand(10000000000000, 3,2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1.3 内部结构\n", "\n", "tensor的数据结构如图3-1所示。tensor分为头信息区(Tensor)和存储区(Storage),信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组。由于数据动辄成千上万,因此信息区元素占用内存较少,主要内存占用则取决于tensor中元素的数目,也即存储区的大小。\n", "\n", "一般来说一个tensor有着与之相对应的storage, storage是在data之上封装的接口,便于使用,而不同tensor的头信息一般不同,但却可能使用相同的数据。下面看两个例子。\n", "\n", "![图3-1: Tensor的数据结构](imgs/tensor_data_structure.svg)" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " 0.0\n", " 1.0\n", " 2.0\n", " 3.0\n", " 4.0\n", " 5.0\n", "[torch.FloatStorage of size 6]" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = t.arange(0, 6)\n", "a.storage()" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " 0.0\n", " 1.0\n", " 2.0\n", " 3.0\n", " 4.0\n", " 5.0\n", "[torch.FloatStorage of size 6]" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = a.view(2, 3)\n", "b.storage()" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 一个对象的id值可以看作它在内存中的地址\n", "# storage的内存地址一样,即是同一个storage\n", "id(b.storage()) == id(a.storage())" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0 100 2\n", " 3 4 5\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# a改变,b也随之改变,因为他们共享storage\n", "a[1] = 100\n", "b" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " 0.0\n", " 100.0\n", " 2.0\n", " 3.0\n", " 4.0\n", " 5.0\n", "[torch.FloatStorage of size 6]" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = a[2:] \n", "c.storage()" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(94139619931688, 94139619931680)" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.data_ptr(), a.data_ptr() # data_ptr返回tensor首元素的内存地址\n", "# 可以看出相差8,这是因为2*4=8--相差两个元素,每个元素占4个字节(float)" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0\n", " 100\n", "-100\n", " 3\n", " 4\n", " 5\n", "[torch.FloatTensor of size 6]" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c[0] = -100 # c[0]的内存地址对应a[2]的内存地址\n", "a" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 6666 100 -100\n", " 3 4 5\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d = t.Tensor(c.storage())\n", "d[0] = 6666\n", "b" ] }, { "cell_type": "code", 
"execution_count": 94, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 下面4个tensor共享storage\n", "id(a.storage()) == id(b.storage()) == id(c.storage()) == id(d.storage())" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0, 2, 0)" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a.storage_offset(), c.storage_offset(), d.storage_offset()" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "e = b[::2, ::2] # 隔2行/列取一个元素\n", "id(e.storage()) == id(a.storage())" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((3, 1), (6, 2))" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.stride(), e.stride()" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "e.is_contiguous()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可见绝大多数操作并不修改tensor的数据,而只是修改了tensor的头信息。这种做法更节省内存,同时提升了处理速度。在使用中需要注意。\n", "此外有些操作会导致tensor不连续,这时需调用`tensor.contiguous`方法将它们变成连续的数据,该方法会使数据复制一份,不再与原来的数据共享storage。\n", "另外读者可以思考一下,之前说过的高级索引一般不共享stroage,而普通索引共享storage,这是为什么?(提示:普通索引可以通过只修改tensor的offset,stride和size,而不修改storage来实现)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1.4 其它有关Tensor的话题\n", "这部分的内容不好专门划分一小节,但是笔者认为仍值得读者注意,故而将其放在这一小节。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 持久化\n", "Tensor的保存和加载十分的简单,使用t.save和t.load即可完成相应的功能。在save/load时可指定使用的`pickle`模块,在load时还可将GPU tensor映射到CPU或其它GPU上。" ] }, { "cell_type": "code", "execution_count": 99, "metadata": { "scrolled": true }, "outputs": [], "source": [ "if t.cuda.is_available():\n", " a = a.cuda(1) # 把a转为GPU1上的tensor,\n", " t.save(a,'a.pth')\n", "\n", " # 加载为b, 存储于GPU1上(因为保存时tensor就在GPU1上)\n", " b = t.load('a.pth')\n", " # 加载为c, 存储于CPU\n", " c = t.load('a.pth', map_location=lambda storage, loc: storage)\n", " # 加载为d, 存储于GPU0上\n", " d = t.load('a.pth', map_location={'cuda:1':'cuda:0'})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 向量化" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "向量化计算是一种特殊的并行计算方式,相对于一般程序在同一时间只执行一个操作的方式,它可在同一时间执行多个操作,通常是对不同的数据执行同样的一个或一批指令,或者说把指令应用于一个数组/向量上。向量化可极大提高科学运算的效率,Python本身是一门高级语言,使用很方便,但这也意味着很多操作很低效,尤其是`for`循环。在科学计算程序中应当极力避免使用Python原生的`for循环`。" ] }, { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [], "source": [ "def for_loop_add(x, y):\n", " result = []\n", " for i,j in zip(x, y):\n", " result.append(i + j)\n", " return t.Tensor(result)" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "222 µs ± 81.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", "The slowest run took 11.03 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "5.58 µs ± 7.27 µs per loop (mean ± std. dev. 
```python
x = t.zeros(100)
y = t.ones(100)
%timeit -n 10 for_loop_add(x, y)  # 222 µs ± 81.9 µs per loop
%timeit -n 10 x + y               # 5.58 µs ± 7.27 µs per loop
```

A speed difference of more than 40x. So in practice, prefer the built-in functions: they are implemented in C/C++ underneath and achieve high performance through low-level optimization. Thinking in vectorized terms should become a habit when writing code.

A few further points worth noting:

- Most `t.function`s take an `out` parameter; the result is then stored in the tensor that `out` designates.
- `t.set_num_threads` sets the number of threads PyTorch uses for parallel CPU computation, which can be used to cap the number of CPUs PyTorch occupies.
- `t.set_printoptions` controls the numeric precision and format used when printing tensors.

Examples below.

```python
a = t.arange(0, 20000000)
print(a[-1], a[-2])  # a is a 32-bit FloatTensor; beyond 2^24 = 16777216 its precision
                     # can no longer resolve consecutive integers, so both print 16777216.0
b = t.LongTensor()
t.arange(0, 200000, out=b)  # a 64-bit LongTensor has no such precision problem
b[-1], b[-2]  # -> (199999, 199998)
```

```python
a = t.randn(2, 3)
a
```

```python
t.set_printoptions(precision=10)
a  # the same values, now printed with 10 decimal places
```

### 3.1.5 A first taste: linear regression

Linear regression is entry-level machine-learning material with very wide application. It uses regression analysis from statistics to determine the quantitative relationship between two or more interdependent variables. Its form is $y = wx+b+e$, where the error $e$ follows a normal distribution with mean 0. First, the loss function of linear regression:
$$
loss = \sum_i^N \frac 1 2 ({y_i-(wx_i+b)})^2
$$
Stochastic gradient descent then updates the parameters $\textbf{w}$ and $\textbf{b}$ to minimize this loss, which yields the learned values of $\textbf{w}$ and $\textbf{b}$.

```python
import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display
```

```python
# fix the random seed so the output below is the same on different machines
t.manual_seed(1000)

def get_fake_data(batch_size=8):
    '''generate random data: y = x*2 + 3, plus some noise'''
    x = t.rand(batch_size, 1) * 20
    y = x * 2 + (1 + t.randn(batch_size, 1)) * 3
    return x, y
```
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD11JREFUeJzt3V+MXGd9xvHvU8eU5U+1gWxQvEAN\nKHKpSLHpKkobKaJA64AQMVFRSVtktbShEqhQkEVML4CLKkHmj6peRAokTS5oVArGQS3FWCFtWqmk\n3eAQO3XdFMqfrN14KSzQsqKO+fVix2Bv1t6Z9c7OzLvfj7SamXfP6DxaK0/mvOedc1JVSJJG308N\nOoAkaXVY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGXLSWO7vkkktq8+bNa7lL\nSRp5Dz744LeqamK57da00Ddv3sz09PRa7lKSRl6Sr3eznVMuktQIC12SGmGhS1Ijli30JE9N8s9J\nvpzkkSTv74y/IMkDSR5N8pdJntL/uJKkc+nmE/oPgVdU1UuBrcC1Sa4CPgB8pKouB74DvLl/MSVJ\ny1l2lUst3AHjfzovN3Z+CngF8Jud8buA9wG3rn5ESRpN+w7OsGf/UY7NzbNpfIxd27ewY9tk3/bX\n1Rx6kg1JHgJOAAeArwBzVfVEZ5PHgP6llKQRs+/gDLv3HmJmbp4CZubm2b33EPsOzvRtn10VelWd\nqqqtwHOBK4EXL7XZUu9NcmOS6STTs7OzK08qSSNkz/6jzJ88ddbY/MlT7Nl/tG/77GmVS1XNAX8H\nXAWMJzk9ZfNc4Ng53nNbVU1V1dTExLJfdJKkJhybm+9pfDV0s8plIsl45/kY8CrgCHAf8OudzXYC\n9/QrpCSNmk3jYz2Nr4ZuPqFfBtyX5GHgX4ADVfXXwLuBdyb5D+DZwO19SylJI2bX9i2Mbdxw1tjY\nxg3s2r6lb/vsZpXLw8C2Jca/ysJ8uiRpkdOrWdZylcuaXpxLktaTHdsm+1rgi/nVf0lqhIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtd\nkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWp\nERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqRHLFnqS5yW5L8mRJI8keXtn/H1J\nZpI81Pl5Tf/jSpLO5aIutnkCeFdVfSnJM4EHkxzo/O4jVfXB/sWTJHVr2UKvquPA8c7z7yc5Akz2\nO5gkqTc9zaEn2QxsAx7oDL0tycNJ7khy8SpnkyT1oOtCT/IM4FPAO6rqe8CtwIuArSx8gv/QOd53\nY5LpJNOzs7OrEFmStJSuCj3JRhbK/ONVtRegqh6vqlNV9SPgo8CVS723qm6rqqmqmpqYmFit3JKk\nRbpZ5RLgduBIVX34jPHLztjs9cDh1Y8nSepWN6tcrgbeBBxK8lBn7D3ADUm2AgV8DXhLXxJKkrrS\nzSqXfwSyxK8+u/pxJEkr5TdFJakRFrokNcJCl6RGdHNSVGrSvoMz7Nl/lGNz82waH2PX9i3s2OaX\noDW6LHStS/sOzrB77yHmT54CYGZunt17DwFY6hpZTrloXdqz/+iPy/y0+ZOn2LP/6IASSRfOQte6\ndGxuvqdxaRRY6FqXNo2P9TQujQILXevSru1bGNu44ayxsY0b2LV9y4ASSRfOk6Jal06f+HSVi1pi\noWvd2rFt0gJXU5xykaRGWOiS1AgLXZIaYaFLUiMsdElqhKtcJKlHw3phNwtdknowzBd2c8pFknow\nzBd2s9AlqQfDfGE3C12SejDMF3az0CWpB8N8YTdPikpSD4b5wm4WuiT1aFgv7OaUiyQ1wkKXpEZY\n6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIasWyhJ3lekvuSHEny\nSJK3d8afleRAkkc7jxf3P64k6Vy6+YT+BPCuqnoxcBXw1iQ/D9wE3FtVlwP3dl5rBO07OMPVt3yB\nF9z0N1x9yxfYd3Bm0JEkrcCyhV5Vx6vqS53n3weOAJPAdcBdnc3uAnb0K6T65/QNb2fm5il+csNb\nS10aPT3NoSfZDGwDHgCeU1XHYaH0gUtXO5z6b5hveCupN10XepJnAJ8C3lFV3+vhfTcmmU4yPTs7\nu5KM6qNhvuGtpN50VehJNrJQ5h+vqr2d4ceTXNb5/WXAiaXeW1W3VdVUVU1NTEysRmatomG+4a2k\n3nSzyiXA7cCRqvrwGb/6DLCz83wncM/qx1O/DfMNbyX1ppt7il4NvAk4lOShzth7gFuATyR5M/AN\n4A39iah+GuYb3krqTapqzXY2NTVV09PTa7Y/SWpBkgeramq57fymqCQ1wkKXpEZY6JLUCAtdkhph\noUtSI7pZtqhVsu/gjMsDJfWNhb5GTl8E6/R1U05fBAuw1CWtCgt9jZzvIlgW+uB41KSWWOhrxItg\nDR+PmtQaT4quES+CNXy8dLBaY6GvES+CNXw8alJrLPQ1smPbJDdffwWT42MEmBwf4+brr/DQfoA8\nalJrnENfQzu2TVrgQ2TX9i1nzaGDR00abRa61i0vHazWWOha1zxqUkucQ5ekRljoktQIC12SGmGh\nS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrok\nNWIkbnCx7+CMd5WRpGUMfaHvOzhz1n0fZ+bm2b33EIClLklnGPoplz37j551E1+A+ZOn2LP/6IAS\nSdJwGvpCPzY339O4JK1XQ1/om8bHehqXpPVq2UJPckeSE0kOnzH2viQzSR7q/LymXwF3bd/C2MYN\nZ42NbdzAru1b+rVLSRpJ3XxCvxO4donxj1TV1s7PZ1c31k/s2DbJzddfweT4GAEmx8e4+forPCEq\nSYssu8qlqu5Psrn/Uc5tx7ZJC1ySlnEhc+hvS/JwZ0rm4lVLJElakZUW+q3Ai4CtwHHgQ+faMMmN\nSaaTTM/Ozq5wd5Kk5ayo0Kvq8ao6VVU/Aj4KXHmebW+rqqmqmpqYmFhpTknSMlZU6EkuO+Pl64HD\n59pWkrQ2lj0pmuRu4OXAJUkeA94LvDzJVqCArwFv6WNGSVIXulnlcsMSw7f3IYsk6QIM/TdFJUnd\nsdAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgL\nXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAl\nqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5J
jbDQJakRFrokNcJCl6RGLFvoSe5IciLJ\n4TPGnpXkQJJHO48X9zemJGk53XxCvxO4dtHYTcC9VXU5cG/ntSRpgJYt9Kq6H/j2ouHrgLs6z+8C\ndqxyLklSj1Y6h/6cqjoO0Hm8dPUiSZJWou8nRZPcmGQ6yfTs7Gy/dydJ69ZKC/3xJJcBdB5PnGvD\nqrqtqqaqampiYmKFu5MkLWelhf4ZYGfn+U7gntWJI0laqW6WLd4N/BOwJcljSd4M3AL8apJHgV/t\nvJYkDdBFy21QVTec41evXOUskqQL4DdFJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUu\nSY2w0CWpERa6JDXCQpekRix7LZdRs+/gDHv2H+XY3DybxsfYtX0LO7ZNDjqWJPVdU4W+7+AMu/ce\nYv7kKQBm5ubZvfcQgKUuqXlNTbns2X/0x2V+2vzJU+zZf3RAiSRp7TRV6Mfm5nsal6SWNFXom8bH\nehqXpJY0Vei7tm9hbOOGs8bGNm5g1/YtA0okSWunqZOip098uspF0nrUVKHDQqlb4JLWo6amXCRp\nPbPQJakRFrokNcJCl6RGWOiS1IhU1drtLJkFvr7MZpcA31qDOBfCjKtnFHKacXWMQkYYzpw/W1UT\ny220poXejSTTVTU16BznY8bVMwo5zbg6RiEjjE7OpTjlIkmNsNAlqRHDWOi3DTpAF8y4ekYhpxlX\nxyhkhNHJ+SRDN4cuSVqZYfyELklagaEq9CRfS3IoyUNJpgedZylJxpN8Msm/JTmS5JcGnelMSbZ0\n/n6nf76X5B2DzrVYkj9K8kiSw0nuTvLUQWdaLMnbO/keGaa/YZI7kpxIcviMsWclOZDk0c7jxUOY\n8Q2dv+WPkgx8Fck5Mu7p/Lf9cJJPJxkfZMZeDVWhd/xKVW0d4mVDfwp8rqp+DngpcGTAec5SVUc7\nf7+twC8CPwA+PeBYZ0kyCfwhMFVVLwE2AG8cbKqzJXkJ8PvAlSz8O782yeWDTfVjdwLXLhq7Cbi3\nqi4H7u28HqQ7eXLGw8D1wP1rnmZpd/LkjAeAl1TVLwD/Duxe61AXYhgLfWgl+RngGuB2gKr6v6qa\nG2yq83ol8JWqWu7LXINwETCW5CLgacCxAedZ7MXAF6vqB1X1BPD3wOsHnAmAqrof+Pai4euAuzrP\n7wJ2rGmoRZbKWFVHqmpobvB7joyf7/x7A3wReO6aB7sAw1boBXw+yYNJbhx0mCW8EJgF/jzJwSQf\nS/L0QYc6jzcCdw86xGJVNQN8EPgGcBz4blV9frCpnuQwcE2SZyd5GvAa4HkDznQ+z6mq4wCdx0sH\nnKcFvwv87aBD9GLYCv3qqnoZ8GrgrUmuGXSgRS4CXgbcWlXbgP9l8Ie2S0ryFOB1wF8NOstinfnd\n64AXAJuApyf57cGmOltVHQE+wMIh+OeALwNPnPdNakaSP2bh3/vjg87Si6Eq9Ko61nk8wcK875WD\nTfQkjwGPVdUDndefZKHgh9GrgS9V1eODDrKEVwH/WVWzVXUS2Av88oAzPUlV3V5VL6uqa1g4NH90\n0JnO4/EklwF0Hk8MOM/ISrITeC3wWzVi67qHptCTPD3JM08/B36NhcPeoVFV/wV8M8npu06/EvjX\nAUY6nxsYwumWjm8AVyV5WpKw8HccqpPLAEku7Tw+n4WTecP69wT4DLCz83wncM8As4ysJNcC7wZe\nV1U/GHSeXg3NF4uSvJCfrMa4CPiLqvqTAUZaUpKtwMeApwBfBX6nqr4z2FRn68z5fhN4YVV9d9B5\nlpLk/cBvsHBYexD4var64WBTnS3JPwDPBk4C76yqewccCYAkdwMvZ+GqgI8D7wX2AZ8Ans/C/zDf\nUFWLT5wOOuO3gT8DJoA54KGq2j5kGXcDPw38d2ezL1bVHwwk4AoMTaFLki7M0Ey5SJIujIUuSY2w\n0CWpERa6JDXCQpekRljoktQIC12SGmGhS1Ij/h/CJYJPfXoR0gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 来看看产生的x-y分布\n", "x, y = get_fake_data()\n", "plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())" ] }, { "cell_type": "code", "execution_count": 108, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4VGXax/HvDQQINVSBQKQaQEDA\nCCgWVBTsiNvsroX1XX13fXcXKWJZK4p1d11d7O66lpUAFhSxgV1BMKGFjhBCJ9RA2vP+MRM3hEky\nyfSZ3+e6uDJz5kzm9nhyz5nnPPM75pxDRETiX51IFyAiIuGhhi8ikiDU8EVEEoQavohIglDDFxFJ\nEGr4IiIJQg1fRCRBqOGLiCQINXwRkQRRL5wv1rp1a9e5c+dwvqSISMjkHygiN7+A0gqJBY3q1+Xo\nVo2pV8cOW3fznoMUlZSSVLcO7Zo1JKVRkl+vs2DBgu3OuTaB1hvWht+5c2fmz58fzpcUEQmZoZM/\npji/4IjlqSnJfDH+jJ/uz1iYy4TMbFoXlfy0LCmpLpNG92XUgNRqX8fM1gejXr+HdMysrpktNLN3\nvPe7mNk3ZrbSzF43s/rBKEhEJFbk+mj2AJsqLJ8yO4eCcs0eoKCohCmzc0JWmy81GcP/PbCs3P0H\ngceccz2AXcB1wSxMRCRabd17kJte+b7SxzukJB92v+IbQHXLQ8Wvhm9mHYHzgGe99w04A3jTu8pL\nwKhQFCgiEi2cc7zx3QaGPzKXOcu2cF7f9jSsd3gbTU6qy9gR6Yctq/gGUN3yUPH3CP9x4Fag1Hu/\nFZDvnCv23t8IVD8QJSISo9Zt389lz3zDrdOy6NW+Ge///hSevHwgky/pR2pKMoZn7P4BH+PyY0ek\nk5xU97Blvt4YQq3ak7Zmdj6w1Tm3wMyGlS32sarPYH0zGwOMAUhLS6tlmSIikVFUUsqzn63l8Q9X\nUL9eHR4Y3ZdfZnSijncGzqgBqdWeeC17fMrsHDblF9AhJZmxI9L9OmEbTP7M0hkKXGhm5wINgWZ4\njvhTzKye9yi/I7DJ15Odc1OBqQAZGRm62oqIxIzsjbsZNy2LpXl7OKdPO/584bG0bdbwsHVmLMz1\nq5H788YQatU2fOfcBGACgPcI/0/OucvN7D/Az4DXgKuBmSGsU0QkbA4UFvPYnBU89/laWjdpwNNX\nHM/IPu2OWK9sumXZDJzc/AImZGYDRLy5+xLIPPxxwGtmdi+wEHguOCWJiETOZyu3MXF6Nht2FnDp\noDTGn9OT5sm+vyBV1XTLmG/4zrlPgU+9t9cAg4JfkohI+O3aX8g97y4l8/tcurZuzOtjhjC4a6sq\nnxMt0y39FdZv2oqIRBvnHG/9sIm7317K7oIi/veM7tx0encaVphV40uHlGSfX74K93RLf6nhi0jC\nys0vYNL0bD7J2cZxnVJ45ZK+9GzXzO/njx2RftgYPkRmuqW/1PBFJOGUlDr++dU6HvJGG9xxfm+u\nPqkzdev4mnFeuWiZbukvNXwRSSg5m/cybloWizbkMyy9DfeO6kPHFo1q/fuiYbqlv9TwRSQk/J2f\nHi4Hi0r4+yereGruapo2TOKJX/XnwuM64EmKSQxq+CISdNE2P/3btTsZn5nFmm37GT0glUnn96Zl\n48QL+FXDF5Ggi5b56XsOFvHge8t55Zsf6dgimZevHcSpxwR8HZGYpYYvIkEXDfPTP1iymdtnLmbb\n3kNcf3IX/nD2MTSqn9gtL7H/60UkJCI5P33rnoPc9fYSZmVvpme7pky9MoPjOqWE/HVjgRq+iARd\nJOanO+d4/bsN3DdrGYeKSxk7Ip0xp3YlqW5NrvMU39TwRSTowj0/fe32/UzIzOLrNTsZ3KUlD4zu\nS9c2TULyWrFMDV9EQiIc89OLSkp55rM1PP7hShrUq8Pk0X35RbmsejmcGr6IxKSsjfmMm5bNsrw9\nnNu3HXddcGRWvRxODV9EYsqBwmIe/WAFz3+xljZNG/CPK49nxLFHZtXLkdTwRSRmzFvhyarfuKuA\nywenMe6cnjRr6DurXo6khi8iEVGT6IWd+wu5952lZC7MpWubxrzxmxMZ1KVlmCuOfWr4IhJ2/kYv\nlGXV//ntpewpKOJ3Z3Tnt35m1cuR1PBFJOz8iV7YuOsAk2Ys5tOcbfTvlMLkGmbVy5Gqbfhm1hCY\nBzTwrv+mc+5OM3sROA3Y7V31GufcolAVKiLxo6rohZJSx8tfrWOKN6v+zgt6c9WJNc+qlyP5c4R/\nCDjDObfPzJKAz83sPe9jY51zb4auPBGJR5VFL7Rp2oDRT33JD0HKqpfDVfudY+exz3s3yfvPhbQq\nEYlrY0ekk1xhHL5eHWP7vkNs2HmAJ37VnxeuOUHNPsj8Cpkws7pmtgjYCsxxzn3jfeg+M8sys8fM\nrEElzx1jZvPNbP62bduCVLaIxLJRA1J5YHRfUr1havXqGMWljlEDUvnwD6dxUf/UhLowSbj41fCd\ncyXOuf5AR2CQmfUBJgA9gROAlsC4Sp471TmX4ZzLaNMmcXOoReRwZ/Rqy2npnp7QrnlDXr52EI/+\non9CXpgkXGo0S8c5l29mnwIjnXMPexcfMrMXgD8FuzgRiU+zl2zmDm9W/Q2ndOH/zlJWfTj4M0un\nDVDkbfbJwHDgQTNr75zLM8/nrlHA4hDXKiIxbsueg9w5cwnvL9lMr/bNeOaqDPp1VFZ9uPjzltoe\neMnM6uIZAnrDOfeOmX3sfTMwYBFwYwjrFJEYVlrqeH3+Bu6ftYzC4lLGjezJ9ad0UVZ9mFXb8J1z\nWcAAH8vPCElFIhJX1mzbx4TMbL5Zu5MhXVvywOh+dGndONJlJSQNmolISBSVlDJ13hqe+GglDevV\n4cFLPFn1mn0TOWr4IhJ0P2zIZ9y0LJZv3st5fdtz54W9adtUWfWRpoYvIkFzoLCYRz5YwQtfrKVt\n04ZMvfJ4zlZWfdRQwxeRoJi7Yhu3ebPqrxiSxq0jlVUfbdTwRSQg5bPqu7VpzH9uPJETOiurPhqp\n4YtIrTjnmLloE3e/s5S9B4v43Zk9uOn0bjSoF9qs+ppcOEUOp4YvIjW2YecBbpuxmHkrtjEgLYXJ\no/uR3q5pyF/X3wuniG9q+CLit5JSx4tfruPh2TnUMfjzhcdyxZCjj8iqD9
j8UT6r\nfrA3q76rsupFRKoUjCP8051z24Pwe6qlrHoRkdqLmSGd8ln15/Rpx58vPJa2yqoXEfFboA3fAR+Y\nmQP+4ZybWnEFMxsDjAFIS0ur8QscKCzm0Q9W8PwXZVn1xzOyT7sAyxYRSTyBNvyhzrlNZtYWmGNm\ny51z88qv4H0TmAqQkZHhavLL563wZNVv3FXA5YPTGHdOT5o1VFa9iEhtBNTwnXObvD+3mtl0YBAw\nr+pnVe+wrPo2jXnjNycyqEvLQH+tiEhCq3XDN7PGQB3n3F7v7bOBuwMppmJW/f+e0Z2bTldWvYhI\nMARyhH8UMN0bRlYP+Ldz7v3a/rKNuw4wacZiPs3ZRv9OKbxySV96tlNWvYhIsNS64Tvn1gDHBVpA\nSanjpS/X8fAHOQDceUFvrjrxyKx6EREJTESnZS7fvIdx07L5YUM+w9LbcO+oPnRscWRWvYiIBC4i\nDf9gUQlPfrKKpz5dTbPkJJ74VX8uPK7yrHoREQlc2Bv+t2t3Mj4zizXb9jN6YCqTzqs+q15ERAIX\n1oafm1/AL/7xFR1bJPPytYM41c+sehERCVxYG/7O/YXccUoX/u+sY2hUP2ZSHURE4kJYu273Nk24\n7bze4XxJERHxqn14fC0k19cXqEREIiWsDV9ERCJHDV9EJEGo4YuIJIioniozY2GuLiouIhIkUdvw\nZyzMZUJmNgVFJYBnDv+EzGwANX0RkVqI2iGdKbNzfmr2ZQqKSpgyOydCFYmIxLaobfib8gtqtFxE\nRKoWtQ2/Q0pyjZaLiEjVorbhjx2RTnKFK10lJ9Vl7Ij0CFUkIhLbovakbdmJWc3SEREJjoAavpmN\nBJ4A6gLPOucmB6Uqr1EDUtXgRUSCpNZDOmZWF3gSOAfoDVxqZkpGExGJUoGM4Q8CVjnn1jjnCoHX\ngIuCU5aIiARbIA0/FdhQ7v5G7zIREYlCgYzh+7oArTtiJbMxwBjv3UNmtjiA1wyX1sD2SBfhB9UZ\nPLFQI6jOYIuVOoMyPTGQhr/0vWA7AAAFWUlEQVQR6FTufkdgU8WVnHNTgakAZjbfOZcRwGuGheoM\nrlioMxZqBNUZbLFUZzB+TyBDOt8BPcysi5nVB34FvBWMokREJPhqfYTvnCs2s5uB2XimZT7vnFsS\ntMpERCSoApqH75ybBcyqwVOmBvJ6YaQ6gysW6oyFGkF1BltC1WnOHXGeVURE4lDUZumIiEhwhaTh\nm9lIM8sxs1VmNt7H4w3M7HXv49+YWedQ1FFNjZ3M7BMzW2ZmS8zs9z7WGWZmu81skfffHeGu01vH\nOjPL9tZwxNl68/iLd3tmmdnAMNeXXm4bLTKzPWZ2S4V1IrItzex5M9tafjqwmbU0szlmttL7s0Ul\nz73au85KM7s6AnVOMbPl3v+n080spZLnVrl/hKHOu8wst9z/23MreW6VfSEMdb5ersZ1ZraokueG\nZXtW1oNCun8654L6D88J3NVAV6A+8APQu8I6vwWe9t7+FfB6sOvwo872wEDv7abACh91DgPeCXdt\nPmpdB7Su4vFzgffwfDdiCPBNBGutC2wGjo6GbQmcCgwEFpdb9hAw3nt7PPCgj+e1BNZ4f7bw3m4R\n5jrPBup5bz/oq05/9o8w1HkX8Cc/9osq+0Ko66zw+CPAHZHcnpX1oFDun6E4wvcncuEi4CXv7TeB\nM83M1xe5QsY5l+ec+957ey+wjNj9pvBFwMvO42sgxczaR6iWM4HVzrn1EXr9wzjn5gE7Kywuv/+9\nBIzy8dQRwBzn3E7n3C5gDjAynHU65z5wzhV7736N57suEVXJ9vRHWKNYqqrT22t+Abwaqtf3RxU9\nKGT7Zygavj+RCz+t492hdwOtQlCLX7xDSgOAb3w8fKKZ/WBm75nZsWEt7L8c8IGZLTDPN5criqaY\ni19R+R9SNGxLgKOcc3ng+aMD2vpYJ5q2KcC1eD7F+VLd/hEON3uHnp6vZAgimrbnKcAW59zKSh4P\n+/as0INCtn+GouH7E7ngVyxDOJhZE2AacItzbk+Fh7/HMzRxHPBXYEa46/Ma6pwbiCeZ9CYzO7XC\n41GxPc3zBbwLgf/4eDhatqW/omKbApjZbUAx8Eolq1S3f4TaU0A3oD+Qh2e4pKKo2Z7ApVR9dB/W\n7VlND6r0aT6WVbs9Q9Hw/Ylc+GkdM6sHNKd2HxMDYmZJeDb0K865zIqPO+f2OOf2eW/PApLMrHWY\ny8Q5t8n7cyswHc/H4/L8irkIg3OA751zWyo+EC3b0mtL2ZCX9+dWH+tExTb1now7H7jceQdvK/Jj\n/wgp59wW51yJc64UeKaS14+W7VkPGA28Xtk64dyelfSgkO2foWj4/kQuvAWUnVX+GfBxZTtzqHjH\n8Z4DljnnHq1knXZl5xbMbBCe7bUjfFWCmTU2s6Zlt/GcyKsYQPcWcJV5DAF2l30kDLNKj5yiYVuW\nU37/uxqY6WOd2cDZZtbCO0RxtndZ2JjnAkPjgAudcwcqWcef/SOkKpwvuriS14+WKJbhwHLn3EZf\nD4Zze1bRg0K3f4bo7PO5eM44rwZu8y67G8+OC9AQz8f+VcC3QNdQng2vpMaT8XwEygIWef+dC9wI\n3Ohd52ZgCZ4ZBV8DJ0Wgzq7e1//BW0vZ9ixfp+G5GM1qIBvIiECdjfA08ObllkV8W+J5A8oDivAc\nFV2H53zRR8BK78+W3nUz8Fy5rey513r30VXAryNQ5yo847Rl+2fZzLYOwKyq9o8w1/lP736XhadZ\nta9Yp/f+EX0hnHV6l79Ytk+WWzci27OKHhSy/VPftBURSRD6pq2ISIJQwxcRSRBq+CIiCUINX0Qk\nQajhi4gkCDV8EZEEoYYvIpIg1PBFRBLE/wOF691fEa+RdwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "1.9918575286865234 2.954965829849243\n" ] } ], "source": [ "# 随机初始化参数\n", "w = t.rand(1, 1) \n", "b = t.zeros(1, 1)\n", "\n", "lr =0.001 # 学习率\n", "\n", "for ii in range(20000):\n", " x, y = get_fake_data()\n", " \n", " # forward:计算loss\n", " y_pred = x.mm(w) + b.expand_as(y) # x@W等价于x.mm(w);for python3 only\n", " loss = 0.5 * (y_pred - y) ** 2 # 均方误差\n", " loss = loss.sum()\n", " \n", " # backward:手动计算梯度\n", " dloss = 1\n", " dy_pred = dloss * (y_pred - y)\n", " \n", " dw = x.t().mm(dy_pred)\n", " db = dy_pred.sum()\n", " \n", " # 更新参数\n", " w.sub_(lr * dw)\n", " b.sub_(lr * db)\n", " \n", " if ii%1000 ==0:\n", " \n", " # 画图\n", " display.clear_output(wait=True)\n", " x = t.arange(0, 20).view(-1, 1)\n", " y = x.mm(w) + b.expand_as(x)\n", " plt.plot(x.numpy(), y.numpy()) # predicted\n", " \n", " x2, y2 = get_fake_data(batch_size=20) \n", " plt.scatter(x2.numpy(), y2.numpy()) # true data\n", " \n", " plt.xlim(0, 20)\n", " plt.ylim(0, 41)\n", " plt.show()\n", " plt.pause(0.5)\n", " \n", "print(w.squeeze()[0], b.squeeze()[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可见程序已经基本学出w=2、b=3,并且图中直线和数据已经实现较好的拟合。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "虽然上面提到了许多操作,但是只要掌握了这个例子基本上就可以了,其他的知识,读者日后遇到的时候,可以再看看这部份的内容或者查找对应文档。\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }