diff --git a/2_pytorch/PyTorch快速入门.ipynb b/2_pytorch/PyTorch快速入门.ipynb index eb55881..1f3a5b7 100644 --- a/2_pytorch/PyTorch快速入门.ipynb +++ b/2_pytorch/PyTorch快速入门.ipynb @@ -1381,34 +1381,17 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "实际的label: cat ship ship plane\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAABkCAIAAAAnqfEgAAA0bklEQVR4nO19WZMc6XXdqcysvbq6em/0ABgAg2UwxAyHo5mhRIkSJdohWrbssOWwFXaEIxzhFz/4wb9DP8ARDlNW2H6wZYclh+RNFmmSokmKnN2zYkCgATS6G71UV1fXmpWLH/KcW9Xd1SNRCke45e8+ALezsjK//PKrzHvuci7gxIkTJ06cOHHixIkTJ06cOHHixIkTJ06cOHHixImT/78kd3rTb/3TV/RZkimFIACQ87zszzAcZkqUjLhDoZApccKvpEnKg3hxpujbSKOqjh8DyBcG2Z8+An0l1dGiTBlFPGyS2IADjYFbhlJy3DPR0XIaNkcbRzqRLtADBxnqWx2eGb2QH/3Gv7uPCdnf3+cAIu6ay02ZzJ9UfrKDpCeV8QYv+5MbvNQ7uUdO8yMlhU0gd05T2/tPGKTtuby8fOKj3/rWOrWYE7W/u50pw8EAwLXnrmd/NmbrmZL3OYBC3qdiW7SMgpwWSdTPlFo1r6/nAAQ+B+l7PMjBQTNTZmZmuGc+r6NxH1stURJmiq1b/pnj371uj98NuJxKpVKmhCG/G+mXUi6VdXyeqDFTmjzs13/zn/EqFm/yKz5/U/WZWqYcDbkUu+19jU2/C93XQMMtB0UAJT/QuHUr7dZpQ5zEJ7Yk2jI+rK7R83xMWwC5nP3e7acan9qH3yoWi5lS8Io6dRFArsDJ6e1/lCk//8u/duIgHpw4ceLknEhwelNoL9iE7y4kCYAiaBl54IMwCE5aTzJZkAu4aWhvG71bgoQfZU9/7Yic7DVEQ51IT/rE19j4XooDPpvDUB9Fno4TA8jJOisV9E7WdXmBvZx1RnDnFPZa4Nsg8KY/033fn7r9zyl/NjMtp7fZ2CLycgASe5+mGm0qM0qvXDMzJ779Z7ewTkutwjvlpVxswy63JGEPQKnAo1XL3CHQ4W0BFLVKyrqbnoY9jG0fro1C3gMnAACCQGaa7DUvd/Lai4IIsuTQ7Y10IkoGI1Itfk8nyMv6MHttNBzqQjRs2RQ44/4mKQcf+XM8SJ4/t9inheXlZWH1O5mSxl2dmscZptxn5CUABpo3/VwQjghoPC3gfo8/c1vSdiEGSjyPSpqEADwzeDVvUaQVaE+AnD0lOD9zc7y0YnlGh+WNSLwUQK7I88adGs4QZ2E5ceLk3Ih7YDlx4uTcyBRImApMAcPJLTnhskQoz68Ihcm0Nh+fudwKBZp5UZLXR/7kPmZM5uSn9/QYzXk0OFOPBnM/oWm7vUcbtRPyW53OSJcUA6iVBAQ0ttkKHZ/lEi8w8eRYFXTyZfDndSFhMh3sGAj6DDT0Z5A/zdHGiMx2Hhvi9kl2IQLmI15yYOAh1i3LnT5jcmrLnyCfMewgx1Mbviv4PH7eiwEUPYF32y5/+bBPx7bv876XAt7E0VDQxuB8xC1pLgAQC+EW8vyKIUEIN1l4IZY7otfjGfd3dzNlZZFAJvPH+wWuDF/HtwnM6+0fCC0OFWewCMBoZD+uY+Kl3B5rbLGCIXGO11Wa4akXnl3htw4PMqXWI0gMB/zNxrUSgGS2kf05I9xtJ/IshjYMNQk8Y6nE2R5PmO5vtvZsBdpBIl1XYmtHS7EQcMmVy4o8wCA5pz1BDCAx++lsx4izsJw4cXJuxD2wnDhxcm5kCiQMElqVlrvhJSNMmO4TURyFbE6F0iJDUhZMUSrK6sqtTGm39gDs7dMIzwcFnU4RQCVM9VHJlA/XaaijtMB9PH4U1ggbm+0mgI2nrezPWpEHSba55fIqT7RQM6BhKWC8xoIs2/hUOgn3lNX650y/+nMhSp05NnyqZLQoSQCMBLc/vc8kspVV5kkZrl+aJ+QpKUaT/ORD+oxJKCgRL4nkSRAQyHsJgLz+9GKuhEJekMSPtWcoRXczJziv5RoNFC70qwAGusCKXAG+BQ4Nt+hKuwPCrjfffCtTRkKjc/XXeNiiB0DYDjmD4VohniGd1KLbgqsWxk2mQ8IICpyBizMRQB4qwutLqSrmV6/olr31o0wJd4kNL7x4C0Bulz+KYY4xx5ou4KjPCGNJwy6mPJq3oLikooQWPB1WSgCCkeDwSEercraLh4eZElx6IVN6jVkOUpg91o0oJbzYXJoC8GLFauMzDSlnYTlx4uTciHtgOXHi5NzIFEhoSCMXNKh4OQhlAPCEm0IZ+YUCbdR4nD92MgOzoCDKF//SX86UN7/3fQCbB3vZn10BwCiiRfpwYydT7j95kimluQuZcnH1CrcUWM8RClEWassAogFt472dzUypNOYzZaPzNFMGuqKVGo3hiooD4pCI4Kwn+uko4f+l0pzPxIwKbuZVGqW80H5nCKB1SLP/6R6rUsozhAYLKk+xWhMLmVmxztTxjc/6p5OCHAipLi1vZSLxEICvMF8uJrjLK1Y7MhAhqOvXDUQo6Vc1NElkML4IoNNuZX/VKoRFnmbSKmYCBYNbCg4221TKSrkMheHCUQIgKNh9VxQv5kgi/Rysdq0gV0OqlZbE0z0M4x+dRfFSqyRT0qeAWU7YbZBTjVGscrclQv7eUQhgdP8Tjk2ek0QVQd1AF6axFSKVFj1WmnGo+i0Fjge1EgB/wD8DXjGGFzik/pbqn3JL/O7sIi9EJxp5FlflVSdpAsCX9yDwzlzzzsJy4sTJuZEpFtbQ42P7sMenchwNAczV+EiuqyInkGfd/KnjysrkZLpHr8uckW/+/n/KlKcHAwBPO9zh4SZ3WN98zOOXaGpFPs2oWp1P66BS0z58LZT08C15FQB7MpHWLl7OlIFsrvv3aWE1W8rlWePRri5TyQd6t5zhJbXKjDT5CQyOdDxBx7aPE1tOWVixJtXKuX29aa1yYne/nSntLq+oP4wBdHsqciryVnb7vFO1iswNjaQwHsyfcBU/kS1ZzFliEWcyr/rYLJdqnEiV6HbkVKPjncxj8nOqEZE5ZlNppfgxRgA6R5yTR5axJaPJjKNLdU6LZV298+57mfL5O5/LlMSSwuIQQCm1dELOZL8nnKE1E42UPhbw+CNVyA+HPUyTWJZXovy41IwJ/cpCy9XSiWaPNBvLzMwqLz/LMaSH2Zi4w+IqR5tXPfM2K6ihipyuwmLpCiNaeVXRDQSYqjNVAOERr2KoyQnK8pdrBQYLtPVyeZmiKU3FGS0fX4ZblMsDyHlKEsSZdW/OwnLixMm5EffAcuLEybmRKZBwt0d7rBk1MuXb/+vbAG7foGPvl+4Ql81ZsbWlooiSwVNFjhU9yLeLBw+ZE9TsFQGkFfrCvZqyP+aPMqWkDI5QaTKh1dnMced6jaN6urWVKe1WE0BdxnBJPtdHAoD5Ou3nna2HmVJ7yjNeqKt8x7MIgDE6HJNuz6gsZGPLtDYuMF88AaYYbZBhQy859s6wIiEDZh0hGvO+l+XKHageYkuQcOeASsbTMBLe6x0RDu/I+77xhNP1wo1rmfLclYscrdKIxv5+o9PKTfw7Ubrhne2I9+U4TwSUPDkQ+odtABBKSkUJ4Iu2oWCEazaBI4YRYsNWsT4ae/dDAN0uE4KePuWe1XpNJxI21EyGHe5TUvhot9XKlLfeJ0isFn0A169xugJB0WGPi6csFpBkyLURKw4QG9YZtDFVLKXOuKjGmYz6SLAxL5RdvPcpj/rGdzIlev2L+lYRQJoSkxYEHgfglda2eIG+mCSSqiqWUsVwRvzWzEKDp36yDwAdLqf8Ct1HeEyAGWiSB7ucN1/em+QmM7MG4njwcubvzwEItFzTs6M+zsJy4sTJuRH3wHLixMm5kWmlOY2rmdLb4+NsVFgC0BRU7IUEWfWCUmDGoTRDQ7RFBxFB1q5M+N22Ig6NBQBzy4zidRJay4sqxLEIYKh8j4Gq0vuKAT2riEZPGHAnHAAIlJbVaurEGmRfBq1f4Imethmg3GoTvzy7KGx7hnXa6jNKVauI1zAwFGxsENrbwiKGBMdEesffGaeyura3mIM2P0/sXC7x0oYDXlqlyC2rS0TrGd9xt8fLqcoIDwdia9OFdcQ2F42LjRRaGqeA2UeTVzNJDoGzpGRsedrJIGExjQHUFGadNW48pY8VhY9KBo+ExD1d+5hmNxYXdjsEMFPl9jnN24MNUjPff0zlk3t/mCmtvVamdAYcWy98P1MCKLuq1wZw5yYpjP/GX/tapjyjFTgscbSDLscfdnmieqqkpP4RpkneV1mMJsHChYkcL8ZAWTvg8aMN5hjWC/ylHG3yjGFpFkAqwsvcFlMaq89wuYZ14S9wkZQ7Sh9rcZADVU1Fe3QgFAYhgKhNuF1sMnw/6guPlwmZWw8Y6y+UCQlnLjCC6SsXLFXi1RApgEgLL0zOxITOwnLixMm5EffAcuLEybmRKZDw1ksMNGz8gHn9tdklAK9/idsr3nqmhIqPGBrKiV8tRiNTZoT43nmPEY1agyb0M1fuAEhl0hcMYA5YrNNTZYBnTTs04A/fezdTZkvcUlGQsVapAdjcpm1sBRueQOK80gVbB7R4D5pU7m/R1l1bZqJdoFGdkKDOq4iF6UaKkEJBHFMs9GPVIYaP0uMppOPooRSrIzEOAMO2DdXZjFQ9DyGLSm0GE5Aw5xv/gTqXlHXLrE+M4rjjGM2pwWSQP3/y88/ChI/X1zVIzuRRm8smHg0BPFHd1YHoIrod4v3lBaK5WpUowlfScmiUhAXx8+n+dgc9AAMbtLjkH21yXT3YYKi0F4quQ4FjVHh8IxW3aq3tR58A2BTm+qPvfDdTbt98LlOWGsRH/U6LI1F7m9FtMpR0RLl3QorCdKnuIIw0RZDZk9JRlVvn1c9nSj34KV7REed25Gcs6UZEqQhjmSfqxqK7kCtgJI6EvFZyXwz6lsfZj2MAvQ7PUtXRBtqzqJ/h/AxZQGI9HDpaclDyankk/r9cbuJCMTp7OTkLy4kTJ+dG3APLiRMn50amQMLKLMHOs9cYEMmKpZ69yoaXiyNihtZ9Jl6OBFLiiGjr9V/4m5ly+dqrmXL1xfVMefNtorm52iqAzR0a6oHYvEoKaakmHB3l9R0e0MaeU9dMsx2tFnxxcRHAQFX2e8oAtFKyWk1RSGWHhgo53X+8kSnLcwQaNy4qNe64fP1f/Rse1hJHZfrOqEfm9auEw6+99ILOyK9bcmkWiUsNv8g+jzSlFuQqFIUajABDWY4Lc8pZtQ5shQIm2AKQl+muoraWAqMtMa4dHbYyZWQ5sQrwLShv8Mb1awDyVqFm3TknQOMJ+c73fpApRtVvRZG9QQfA+vYT7UCxWZpT5nBVgdGizpNXKmmgvEdPbb56gxBAoLasqeDwVlPE5wrfVmoNnVMEJB2r9eOZrAS1XqsD+OlXX8r+7B42tQNx96NHnNJ79+7xIwXPH+5zSvsKIJ6QapXrLdKVjmK7C0RzRpeSEwour3B+2urqunvI0eZ8H0CoZmUFC8C1uGck5F9Ujndba7JkHQ2MLlE+jWFWnaq2DId9zZvwa0V1jjMXL2WKbx6GcWc53eBxPnIKjNdTcnbmqLOwnDhxcm5kioXlF+k2e7L9YaZ84dXXAFRnaZj4R8biIONCr9wfb9Ab93NzTOZChQUfM1VVPwQ8frlQwUQxhPmS19bo8P7wHot4CnJJtuVTvHqJ1t/N51lV32yqg0g9B2Bzm4knnt4SDfFhtcSUZDZXudLIlL7K0D99pPKgwvRn+kD+7LCv8nSZM51DXbq2xLef57dSI/YVL22hjAlTZUx2LFNrdp4pPGMiB+t3YvwNskmtACrhvzzausqhnuxwWpr7tFX7fdWRDPW2FKODUQtcvESf9OVLFwFUC7ZsLHRwpoX1zl2euqISDbvRg6gHoDHP3DG7y6GMmp2O5laXXCtx7UUyFT1rrarWSl5QBVDoqhvoiC78ZrN5Yth2a0MRRbR1RqsGu7zIZTM/fwETFT/7B5zJhQbP++rnuRQ3NmmntwecqI9UuXKaTJwXKD96eYYX2FHKYaBVGltClipaPC2nRMliOV+xCM/HhLN8JPKSstomGbwwW9V87bHm1nrwRCqJy5dzABKlvBnJnfE65CM1NrZMQ323FNsql+VpnNXIYeJ25M5eTs7CcuLEybkR98By4sTJuZEpkDBfordyMDBoMAKQVyFLpWquUDr/iqJbnQloQ/7WP/8XmfKrf/ef8LAqUygUzYaPAFy99kz2506TzteBHJ+rywQLRlw7VCuUa9cZAXjuOrHh4dvsd9I96gBoi/Q2iszFSyO/oXyZpEWre3ZO3Axy1fseL2RjcwfT5O/8rV/jkOSirp7qE1kWdDLO4XZbbAoigcgHJQCB8llS2ed9ZS2liXLQhCby8u4HZsznrdDnGKK0fJaBaA+MsWCu0ciUWCyAJZ/jb+0T9Ww8Wc+U64q3+F6ACdzqC6V+RmnOkSXCmatb2LDilQBcvMQ8pnDIkext8yt7TQZkVpfJBldapCu32drXUblzfY64tVScAzAQy0Yv4pyXKrrvEe/7uLerHPbm3Ih6HO3rP3UnU24+uwZgENJr/uDH/Mq9Tz7IlJ957cVMuXSZS/rRe4xKmb88OaPopKBsr4LyChPR3ZUVMInEgHjUVutTEYSUZolbV6qKEaUJjrUsFQOibBRf3oNxZOaUpCoPMkgY+ykmGBA9KQVDnzrsUOSLRtMS6NpjTfu49VQSYKJwzSgqT4uzsJw4cXJuxD2wnDhxcm5kijWYU3FAT9Bs0OsDyKu95dG+akRUiJMHQcSFBi3DTz9kKsrmBhX0iPgePl7PlC+sfhHAM88yJri2Q6V7jzvMFxqZUm8QG96//4AnWqPV3RLIGsl8fbqzDyAxqgRZvD3F9bxTDAxW1mMBrIJ4zsL9bUyTRMloYxtbH9UKrJgplzhjfdG29UacuvX7vMZCoQzg8lUWsj94zPr73/uv38iUSNGcko5WMUVAslEn2GnMEhF84QsvAVhaZHnEcxc5XV5OnIKy1C0SZGGj/jLxxdqFBpVn2Kwo45DrKbtnjILPfvHlRcy/uMwxWOB1b+8xgE5XBAYqzbBksYaYyNeu3ciU+iyvqL5IkLin6HAi7JxVofTUKLSncFsYKrNJJAQFY3kMeMsKYmpfXuWULs1RKeU9AEsCnnWlL+0/JO57+OP1TFlV3LO1/X0edp6jDc/AX4F4C3w1iC3pZ9jaYXCz2SFlwu4Wo5BzM0yZvPMC0ai1K874D0aKx1lU2parNSUwV0NuDPC5czwOR1o8L/vIvqtqm/F31VBHZ7QlZzvnlRmXt2BgCgCeEG58dlqfs7CcOHFybsQ9sJw4cXJuZJqBarUmCg1cWFzABBL55rtEeXMKAN2Yp7FXKljYhfhrd4cgLhm2MuXydVJ8+aUigEqdRv7iClNM91VC0VJw0PgBl5eZRRkIn1oJjpXvZ9FAK22xDMOBIoyRWMYXBSvgqQmryMxKinFEYhM8Ib/7e3/AsYn32lPyXk3h1BkhtSs3eGlLC8RHCxdYtTO/uAygJDaC1kfEF+9//ChT+sKvBibsvsyIrv765SuZ8qUvvsLjV2cAVH3V0MjEDjVdkdpk9awiRw1ByzpsoyG+/G02RtvbawIoq45kZZUTWKko+/eUNATnJwqhRMIHD0BT5HnttlIljfNbKO/RBgdQb/O79dkGdxYdXE/5rshFmKwvqfB2lCpWxGMAhzNZ0z6B2pdeWuC1G1tDt90CEAlgGp/9VcHVjz76cabcvPW8js/Z3lQqaUmFVifE4Jh1BkiE1I6ULL27S+/EQZNH++S9H2bKx+8Se16/ziKwK9dvA5hbFAuFQJaxSxpPv6Ev3+hGtC0Y9yI41mtuoh2sgo/a08LFpzsNm4yDj2POkuws9lOd3lsPzsJy4sTJOZJpeVh6WNZretPOlDHRUrQtsqC9Fp93i3UepypPZKQOKOvK5VmZb2TKs3oJZJkyP3zzo+zPJ1v0ns7UaHPllYHywaePNDorPVG6hx7GnS7fvXML8wAi7bCtGp1anQMI5HSv6L1qRSEImfiT9DiY1eXpxc8/evt/Z0pZNEzDkJ71vJzKP/3Tr2fKwyeki92n2xR3PscyjkK5BKA3pHWQlxn7yiukOhqoGap5iG9cY9nT58SytLbIS6tXaPskgxDA4232B905EAf0Hrd0O/RJt1QcHo7UKV4nsnJrq8EajSIAlQbn5A54FbOz02cJQHC8JhkTL8msXNnCI4FqtmxLocTDLi7R61+r8QJLCjgEGmSQ543IctBSFYJY36NZ5aB51u1JnFCB1bgMlZqnMus04rTE8RBAqNKTvi6nMsO0xIfb9I5/cJ/Wt9X3jFT2lLbPTHrKxEyVkvjBn5e9dv02oxa9I973D95i7uFbb9DC+s531jPlww/fB3Dr9svZnzdu3c6UxlwjU2w5+f5Jw0qVXZNbtACSGBNZhCZWrBPLmE/GKWBnypgVLudjooouSs7M63MWlhMnTs6NuAeWEydOzo1MY2uQg+3C8gXt5AFIlLBz4SIhyY821zOlBTVrCegmbyzSLTdbp6Fu+ThXBAlrswsA/uVv/uvsz56O3+6z6qKn8hrzN682RJXV5Km7RTsRvaQffbwJYEe0BObKbXgcZL0haCAWpCCkMR/0mAY1XxGOKE03aXcfE6XOy8a+eJEe6Bc+z2qhvGDF++/8MccvO78mkqOdvS0A1TphxUKdO/z1r/08B6kcp9lZ7rO4wOybZpMT9eAh6acPW4Sl7cMjAEeKWhyIhqmpfieRQhAFebgLgvNGYlGvc/xWxzO3PAOgaFC6LGoBUVaclnmRTSehebh5oiTqAyiIZWF5ZS1Tcqo9KiiryMBpSZUrvgZptBbG/pzlBFmiWa+rQhxjgJI/PhU27B1yJp+scyabyhFqqKvrykIDQEl0EeYYTgOi+EClP3tqZnNxjUuupmtvD6a7k61kx1oRp55tkWNbmVmNBdYn/dxXuOSuX+dP8rvf+lam3F/fANB7WywUYih58SW6Gi5d4kECRWbiyBi9rZBI12jO9DTFRD9gIxCx5k/GdTXuA2ttay29y+qTxk53D0CSnsSVp8VZWE6cODk34h5YTpw4OTcyBRIa8W59jsZ8FAcAijJ9b4r590dvEFsdFljNn4Dm98pFHvmDDxm/+Nmv/MNM+b44c7vdNoCRAnM7WydDgZ1I8SNht4ZH7PZMmdjncJc2fOQ3MmV1pYEJa9YqcgbiQe70xMWsjqrRgGVDywFDjWuiUR5GVs9xTJ7cZY1+W7GnX/3lf5wpX/vaVzPlD7/JaNGy+CGW1XW1rFSgUi4BsCI+3xkpJSVDRbLGDRZFSmPZ/oTDfrTDNKVQ7XOCUhXAzAyzfpYFZEbhyfhOXkjQSuRNmZlhkK5en9FHOQAdEfI+fcp7Z3N7WioCSpGnsJqSzhr1ZQDJmAaS96Vc4+lSq+oQbElSbTlFs5uaggRApBsXxRxbe19k3HbtgoSdQwZPN9XCZ3Ve1U5Vpv5lPZwSQdFIh7Fw5DMCWbduMtPw5Reo3L3PMPHb732EaZITEvTEZeyJ+CTvW6GMsqIUxfMUGL1xk8TNiX4ym9v/AUBzj+A0UXHY0ycfZ8pzNxg3vP05fnd5RS4g/dKjkfialcwYpzEm7ssUamzh7tMkfGOWx/HF2pdSYIwwxxU/p8RZWE6cODk34h5YTpw4OTcyBRJWa4Qtc4uMcWQ97weqXynVZC0rePToEYsGvvw62c4GHSVn1mlsb22wnuDTu3d52CjEuDEHOuJdqM8TilppTkMprLdu8fg/fIeW7Vsfr/PUX/mVTMmIBu/f+/TEQSzXdKDqisurRHMl5VvOzwuMiJIwCqfnsA16jLu9+HkWyv/SV38pUxbUKfZnv6hIn6DHjCqK6ppkv1DCRDdQi1sZS7c1CqrLUE9EDHFNs7F8kXHJ5gHncKbRADASWskJLxlvt4WlrOlLR9G0VC1SjFb88RYTXgf9HoCRUHasEo1K9czSHIPktQrn1vDdzu4+gLYyVy1f9PotJkYa3bufNzRExXBxqIYtPVHr9Yc9AJHyeD2VHCVD7lkTCjaa/3JBJV+KfzXkE5gVyXo4HALoaZBGN+ipoGROcL4iisqNxyy0EqrD556/gWlihP3+WJErwOqIrHQmORZcAxAK6V+8dCVTrly5CuANaycszsKdnRYVocWPPmIXq6tXObbnnqOyssJU1RklxyKXBzBQW9ZYv4684LyFAi1x1CpzUuOxHIutzxwmi4Qcp7sTJ07+Aoh7YDlx4uTcyBRImETEULPzREzdfgygJ3xhUaTLl0lCcPd9orzDnpIDq4wkXiZhN9Y/Wc+UzU3ii5/50msAusId9TXmDc6vMbbyqEnc11NL1UKVNvzsMo//Sp1j2N1jDGh9fQNAR7X7LbWWXF6i2V8HjeErNWK35brI0UEcESrGVD2DS+za8y9nyq//g3/EQcYEGp/cY8wuyYnEQpHEkTLimi3Vuyc9ALG6ZipGhATEL0dtFuv7T2n2byondihUkigdsaoo5P1PNwA8eKTAq1IxFxYX9F2l6aqR6p4mEEogHDMdSqlVygAaJZ7FOAX7nemxVEzkozb3OOz7YmqPkiGARoOlo2trpBYIVao2Cgknk5RDaguJ9/rG5DHUaIWh8h4mcF9J3BJl5YuaTyBRuK0qBkdDZAVV2Nlqz8KpRi6Y80/G7Eai4d/YZ+Vmr9vKFCuoXL1wEdPEF1wyBToRcgrsjtMsT9X66SOrQKzP1IHJuk0pRrGvyHu7yfvy9h7x4wfv/ihT5lX/u7rKn9vq2hUApZLynBcYWFxaoRvH0nftlkXyMFjr1nHiqOWdJh4mWBzSM5jv4SwsJ06cnCOZYmEdiVKgLA/xcBBCnS0wkZi/NM/X9V15zneafAHu6Z3cmOGj9/aLdEnef8iclIzAypziN2/Qc3zjKq2y9c1WpmSl5wD294xfQd1f9G7ceJ/m2NZ+G0BOIQJfFf+rF2m4XdFT+rJ8+cZ+NZQplyR5DXJ6LcWv/f2/xwGs8p357vuMKpgHNBy3CVG9hVy25lbM+prE9m6xHp/jVwm3hHo37u3RgrNUI7OEGmKJylzRzX01Rpe/dm9PjULFFxzJ6W7FOtY5pqK26SUlH3mRDyC0jjRqf1JWatVpaal+aGuThm1VlT3Pv/AigAWxklUU+hiI3fjggGl3xiTRE61CRXlqs3Wu0qp61pcLeQCBbKVYTvcsyANgJKLqgXV2GXP+iqVXeEKZbQj8AoBULVcHQyr7uzQY9/YZX7JqMGPCMF6QokiNT0guNQuLW8xFnZOpYtwGExUxVMzn3e/QHt/e2gKwtUWjqX2oCjkZhjMK+9RklJXEO2IUchvbXNJ319kNdzCIAUQxD7K4RFR05w7r7W7eYDLa0hJva32WkZNimU+AFFot+oHQpjfabud0d+LEyV8AcQ8sJ06cnBuZAgnv36P5d1nJ+yUvBJAIRARmQ5qHT07lmkiBn3+eqTR/+Af/JVN6LVqnlQX6Vu9t7AC4eJH+vKu3SO9bFCR57ln2kmk1W5ny4Yf07icCIxsHtPPbfdn5cRFAu0WkubxKG/VRk1vmLzYyZV+QB+qV0pK/OVVDoEEyxDR55503MuW9997JFE+GrifT2koczOcKGCMCjeqg4GFiJgtjogLx+SpFy0v5Ub1IL7UnAozIt2tX+lgKAAUhkUgsgL2WRRV0XVasIxQaynsdqW1SRzioUggALIvALxAuK5xZSoH5Zd7ueWEEYwHOFtKRCqSOOup4WsxraOLVkxv+mRVGToq6d771jlUxVnfQBzBQsKIlXGmQbTDgGW/fJjdeXhmFE3zBJ1v4DLtHADa26dDY2aWvOhSUNnIRyywryFXS0TV+4xvfwFRRMldiOVaR6mOEFq0PVM5X0pMglS83/LtvvckztnYALCiJ7PEWr72uZLG8VrgF2epKPQvEjlIITnpgOl4HwH6LgZr1ByxQax1wWt56QwtYpJiXL9MVsyZa8Atr/EmurXBLtTYHIFcW5YN3Zlqfs7CcOHFybsQ9sJw4cXJuZAokfOdTBqEu3yEleYIugJzFy2S1ttXPo9VioGRh/uVM+ZWv/WKmvPx5Wt2//R9/J1NyKvWenZ0D8Mwao2zGue5HDBLNr3J4a9eICFrCIG+/806mbHXEvZ2nrTu7ughg8Tr/9AXHYpnUd9UI59620rv03LY6la6uNUpsir6FCfmjb/+PTOmJGq2Q52HLqkGx6fVTVfZbG8u8QcIcgFLxJMouiF8hUOpZSW1lDWgodgevZC8ehV3CEMBAZTEjBcgs88heVcF4i65UkHy2RujRqHBLreIDKAT8Sl4pQrl4OnDGRGcUS9oKBHuTDOxYDYoyniz1rSTcN+hy/P1DLrm+uq8GBZtSEcXFEYBPPiRaebi+zpEI+BuSWrvANKJ5kSP2BetMOZA7otnaB9ALLf/L6EC4xXr62s2oCFttqbZpW7UyJ2QkhG4h5lwk2gZDi9o5VQqVhRQ7Cg4O+jzOrZsvAHjl5deyP994jy0I/vhHzLE6FKl/rLWxvMqQ35e//OVMCXTL1tUs9vs/+D6Az71ALv+6XEA7uq5tJQlaTHZVJBBXr17hGRUT7x4d6opSAHm1sx2c4hQxcRaWEydOzo24B5YTJ07OjUyBhHfbBCN7sagL8gMAXij7LRH/lrLs1tRQ88tfYqSvlGfc6uqzLPj+q3/71zPl3//Of86U3e1DAJuHRhvA/qwFWbzNHpV7YoOAIjLpIpHm3DJHa2AnSxlNSrZdJGRCsoeRijbGiZG0rbsezfuR4l5pMt06XVmiMbzVZ/wljluZUlezzEClOe091moctWmHj2KLfw0xtRZBHGaFMufWMG805njj+6ai1q1VkazHWVauHVb8ATkBqJJwX1mTMK8U3EtSLq4xJCcgjsHgCICndrOBMEmjXj45fsndTz7MlBfuEEfYtGej8xSaS1TDYbCiJwb6QY8RagsXWqPca6IzX15mgmJW+REoVjsr9kQ7ryXlWvLnRx9/kilGUGHOAcuizBZYR0mhfXEWGiQM1avNOOMfPeXasAzS+IwGVuO2o2P2dP5vJHlCzEgEEi2oWVY4+Mtf+ao+8TDB137zZbp3XvwpKgqujuffegVcu8bM7UAzduUGSf7WLt8CUC7zdlufARu/9Rkw3Le8xNRxo3zwhZQ9eWniZAhgpCtNctNnCc7CcuLEyTmSKRbWJ2qP+rvfpaPuC88uAlgtqHm3XiAXVvnsvLDIl9hz11TbqRKKrV0+cb/+b2lYvfU2X7lZxc9E6YucpnLXxSUeNjY3s/zl1ic18tSI/PilDELrPmJ9t2kn+LI7UtUMR7LO8lY6Y0lJ4fQqgXSkEvEq30JH1jUz5kv4+du0KZILtLl29zgbO6Lr7bRiTLylYyVSJRGPVg34Xnr+JfJQb8q5uyt/fz88+drPSn+KKq6q6pY1qpyuJTX7WV3jTbyu2uOVEqeuo8SofdXHZh7uquIAtRm+aRcW5nCGjJT0NOhwtJ6sy8yKMHos63h67y7tnSMLaOidnFdQwhrfJ1aqbWW9cQpgcYGDNHuqN7aeqDx69PjEPqZYp/ieCrAPWy0A3T21yw1s2Lwco+jqKk0pUo1RPO7tPt126PdpQvpKHwtEBh3qpxQp9zDSldphjd3MqneiOMJEM5tQ1uva5au6QhWHSfFEmvbgETPX+qGhFrFmz16dPN3BofpOaTaq9Su6UNX5H/LSNp82NVqOsqj6uayyKFdTdfrBmU2YnIXlxImTcyPugeXEiZNzI1MgYUd22jfeYh3Mp/fuA/grr7Ig+7k1gpQH90lD/POvkau3lKer+EiI7Lf/G/M+3v6Axfq9SAUxQQmAJzfwuJeknO6pZz45fjQUZBtpS07JNUMdNjM3g+AkuKtUZH/KtI4NQ2ge7ESR2mQWlB12QvY3Wcgej2i+9mXt96zHqjpfLolAKj8kZCuLYKHvpwDS1ICxsIP8jr0+weOXXyfAvHObpMyPHjE7Zu+ATn0rE8kc2oabSjrdkiBVQ8xZsc64vcejfSJeJMjnWl8mvKrM1gFUZvjdebFr1eR8PS1l3YhQiMxCHFnQxpMzuSDcak16LDJQk1PZU2ZQRRcSKWfn7sek62g39wG0VA2TWLtcHS3QkiiJ5MD4LnrC9bsi7TJI6HsBgDmth1B79pQSFokEIhkDwJO0CrncdBPh29/+nxx8RMLiipKSElE/G/nHGIQmlhopwmgLESQRAE9IbSBwl4z5sKz+RlGXBmMstZquUT147Dv00Ht2B5UEZwmGenpY0MMbE4+cuvZxUmAMAFUdRIGs0+IsLCdOnJwbcQ8sJ06cnBuZAgkXFmkZNptEJVutAwDfU6OaePSs9qXVtyQSO/i02H/4Bin3fv+b38uUYVLVOcVq4B17XMaWYyVD0ZqhGieslddYjCZnBSXGkeD5mMj1mDH227H5OjpxtEQkCmZar64S49TrVN7AMVlV4G/jEbFhNLTsGCoPFO06VJ6UXXBX6V3daAQgiQ0SiodaIGI4IOJ467v/PVN+scoruqMr6os+wUJmWR3VwCJcAhE7exztQxEW7/UY9hoIHpUEAOeVXlea5fj9cgEChgCKwpU5f8pC4iVrkAZGPNVmZaMd6AItx6pkeTpSLMAXNulYeGw0x8ZZbKHeoICJkqygZEfjkELh/Y4oPSxuaB1hLTZc0vhH/RDASIwCFpC1AJ/5NCxzKlKiYhob7J0edC5phURCgoFqwoLCrEaiuKQ5T8b83epVYyCRa83CiOHx7ZOXaI2UbA9ROap3VDhQ6dXx32ykJrgGzC030NNoPZzEjyZhx259BGCgz0vBHs4QZ2E5ceLk3Ih7YDlx4uTcyBRLPpAdmy+IQmxQBPBgx+okmPn5C6+Qpa/cYEF2W5zov/HHhFAD2agjGfxFMXtlJrTlTJr4MiZPW8/F00jQdlYtTrlUxkQmm5GyH6nhipVHDAVSZhus6li9QKUmHNETI8UJuXyTJGTtLiFVd8PsWHG/Ceg1daKCqmpChQXjjLHbYLAdIrW4Erfce4/x1sdHnMklT+1XlS4Yy+rueAmA7ZRo5Z6ikxtiBegZAcNl1uivXiGbWkm1LDCgJ8q9Wq0GoKIonqfE1PSM4BeAtpg8ekoc3dkUB8MgBBArRdYoJUaCbHZd1kI0L6KIcRRYit3xbMKMm2HQURxZRAtHh1QsNludUVKxJjAdKTAtFsMsr/VQ3YYMCcbKyTRi+OTU3TSCitwYsh0TyxPudBjwrcjFYceyTsAWCgwjG5syLT2LG0YAQmPpEPeDJZ2Ow4WG2cdI08alORT+zb5lZXATFWXxCcUaqXqnfsf2USAgORpFAHpzXFcXLs7gDHEWlhMnTs6NuAeWEydOzo1MbaRqPT5lKwYlAGFEu3ynQ/vzzY8ZsvmVHm28o5QAarNJpaQgXdSzHDYRhFcqAAIZq1Yfn9OoLKxgMcFUHAbGhGfFZZ2Qww6jLgQMMVH+bgCwO6ChWxMSnFtmPZ2RiH/8gCHRfGK27jGpNxhKW1phKG1LkNAsYKvMH8pOtp5RsXo3xTgJH04M2w43Egbp7jGtzis2MsUX68CmTvQOhgDuCUB1apy32kUW/S2qbe2irr2o5MwxF5+gTVEM9H7gA/CtyaiF87TltGw//FQHO1kBl0XTAjG4Ww/OnHUzzRMWVUSOmDuFXyIVhHYiRRKHEYBElXFjAjz1+yoUGYlbfoaT0BVcbR9QsaZn6TgKmcMEgZ+xOKTpyTtl2DBvRAu6y73edA/DxiPSDX66zfNWldQaCEVG45XFGYv1UaKgc36chj2CKgoB6NLHLgZrEGtd+8YxR9tH99dmO2OkSOKT8VBPvo6cGErG5PRaRafmCR3l9MZzFQDPvMQmEnUlFJwWZ2E5ceLk3Mi09JkxZY96cnh5AIn1mJSZs77D18XXf5utcb76lVcz5f4mrYDeONdJvnxVV/iFAoCK3pkF2Up9FVWYvzyVcZQv8dS+3vnmoLUt2aO9b3k6uhzboSHjaGGVsYLdPdaRt1SV0npEu+D6NVW3H5eSOtYUdTn2covlr5XfHFHu5JSOi/azneztY/vpLZdK6egt95Fe8rNqqPPxgKzW74tdulmvAJi/xMGvXSErWWON116ocPzWhHWksQWya3wpgd722Rt1bCLl7AV75pvPT5SmFJu7VzZLdjRL2EntLc3vDsW8HIkbI9GcTvAfUMzpnnUVNesgkKkVaxWViqpYUu3RwR7tmq5iLHmtdt+6ew6HmOhhYybweBK0kq3jaUlLriPaiV73ENPEg1aRLYTYiKRlAdkk+5pJGVDWHtWKsbI5tilNlftmk5sadDCCCuvBI/7uWJ+NzJTz8wBS61Rk5F1mnVnb1/H8KI6h8EgkMuu6mEIuvngTQJDj7WjdfR9niLOwnDhxcm7EPbCcOHFybmQKJJxXU0mrmehGIYCCOi9aKoenRK3v/PC9TFnfpBv+sEcbuykPvTJCUBUYySoMijqI4Y6S/OW+zHL7yGzUSEAvZ749mbhxOMJEBkpZSHNxntQC80tEgqHAwlB1/H1j7xWgyLpynhYrlO+qWH+mwRMNugQyxv0QyyqOxwa/xm8W9HFJhX1SJUN1lWLzXXFVPxKF9H5FuUgrzA5bvbgE4NoiowoLCi94mvyuAKDVQwTywtaMOVo7B0qdK5UrAIqa0rzIOT5DzNU9ZgFW+lOa5ACkikSMkaa+ay722Nz8QqnFohwLWiTG+pDy4OZv1u1Q1CJU+lhP6UXd4zUiAHIFHnagPMFs/FoyY0xvkNC2GBtEGvLUB/vE7KPwjOVkpJXaYSSsbh9BxTqWg5gIf3maWyPqS9IIkzBcnpmCrt3wZZIeQ+iYyMMajeSrNy97mmKiFe6YhcI8C6nc/7mTP9WRqC7nbpKC+eJVlvQNnu4A+LH4NsoipDwtzsJy4sTJuRH3wHLixMm5kSmQcCgQpE4rGCYjAHlxM0RmFRv/gUJm65vkALA6+yg0y9a646hZaa+DidiKlexUraFLhdjQk8FZMP42oRWrvN9tijEaESZKN+aU1LG60KCyykhZS9it3WI9REfdTRrqfLO3M71wPFQxhF+gxTu3xBON1H80UrhwZJE4494WJMyuzDJ3cqeCg1BVRyDeu1FZpS2zHORzsysaNqtqavUAwIygYlGVRgOr6rB4pdHaiT9vjBY0hrwgeRZpzWtPS8hKz6AqBzAIrfTfIlbH0nw8XaCRu9uSmIB7wiCWPWSwKzm5wLJamZHSCX2t55FwX6zDVrUUDQl6RpLRV7HL8T43yal4riVkBQLINi3Np/w5jIaM3uZOQn+JBQBF5+ApXphXTA3x+IfHnRV5H5M2mIshzQEoCdg26mK4twvRsM25YT+Zgm732Pmj72WRRPtK58hi8Tqs7mbbgs6LPPXlmzczZX6eDoqNj9goa//efUxknJUKZ02Ts7CcOHFyfsQ9sJw4cXJuZCokpDFcFAbJCvsTxS9yFqSwou2xchIJpslJ0z0dl3onmLD/Dw6I6ZpiSa/XWJAxK4BW92hMJurtGSVDXYkgQMkHMBTlmDGIB7KWo566MPW4T6fF7luJyAyM7XsQTH+mB0oTbSwQnNbESB0PxWomKGgNoNIxmZlRC3iYQCLWm9aI0AJBzrKyEOs1VZbUSO1WE1NFTbA6s+pDcfJ1NNqeAQHjNRcrQEGIzACgAbEx/kpTAKGK7AsFKUo1PC35ovE1KnPYPAmehwmmh3FwcJxmezKwCEUSLQJrlWSRQloZsX1fSDDuq5hGUcKqvlKeZeDY+OdGKtvyToE3onWL/I4bnlKrCq522/QwtJUvaojZ7jtUlcLt5mcx8nWVSKXim/RVkWOKDXLM22d1NrkUE5yIvaCtAYxBob4rX42mJQjtbnqnvnVMIo3Nzmvp5fVlFYHduqZj8UQf//AHmTLc4e/Oj2NMVAslZ7SbhbOwnDhx4sSJEydOnDhx4sSJEydOnDhx4sSJEydOnDhx4sSJEyf/z8r/Ab+8NWulkQjIAAAAAElFTkSuQmCC\n", - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], "source": [ "dataiter = iter(testloader)\n", "images, labels = dataiter.next() # 一个batch返回4张图片\n", "print('实际的label: ', ' '.join(\\\n", " '%08s'%classes[labels[j]] for j in range(4)))\n", - "show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100))\n" + "show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100))" ] }, { diff --git a/demo_code/CNN_CIFAR.py b/demo_code/CNN_CIFAR.py new file mode 100644 index 0000000..d747a2d --- /dev/null +++ b/demo_code/CNN_CIFAR.py @@ -0,0 +1,129 @@ + +import torch as t +import torch.nn as nn +import torch.nn.functional as F +from torch import optim +from torch.autograd import Variable + +import torchvision as tv +import torchvision.transforms as transforms +from torchvision.transforms import ToPILImage + +show = ToPILImage() # 可以把Tensor转成Image,方便可视化 + +# 第一次运行程序torchvision会自动下载CIFAR-10数据集, +# 大约100M,需花费一定的时间, +# 如果已经下载有CIFAR-10,可通过root参数指定 + +# 定义对数据的预处理 +transform = transforms.Compose([ + transforms.ToTensor(), # 转为Tensor + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化 + ]) + +# 训练集 +dataset_path = "../data" +trainset = tv.datasets.CIFAR10( + root=dataset_path, train=True, download=True, transform=transform) + +trainloader = t.utils.data.DataLoader( + trainset, + batch_size=4, + shuffle=True, + num_workers=2) + +# 测试集 +testset = tv.datasets.CIFAR10( + dataset_path, train=False, download=True, transform=transform) + +testloader = t.utils.data.DataLoader( + testset, + batch_size=4, + shuffle=False, + num_workers=2) + +classes = ('plane', 'car', 'bird', 'cat', 'deer', + 'dog', 'frog', 'horse', 'ship', 'truck') + + +# Define the network +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16*5*5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(x.size()[0], -1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +net = Net() +print(net) + +criterion = nn.CrossEntropyLoss() # 交叉熵损失函数 +optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + +t.set_num_threads(8) +for epoch in range(2): + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + + # 输入数据 + inputs, labels = data + inputs, labels = Variable(inputs), Variable(labels) + + # 梯度清零 + optimizer.zero_grad() + + # forward + backward + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + + # 更新参数 + optimizer.step() + + # 打印log信息 + running_loss += loss.data[0] + if i % 2000 == 1999: # 每2000个batch打印一下训练状态 + print('[%d, %5d] loss: %.3f' \ + % (epoch + 1, i + 1, running_loss / 2000)) + running_loss = 0.0 +print('Finished Training') + +dataiter = iter(testloader) +images, labels = dataiter.next() # 一个batch返回4张图片 +print('实际的label: ', ' '.join(\ + '%08s'%classes[labels[j]] for j in range(4))) +show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100)) + + +# 计算图片在每个类别上的分数 +outputs = net(Variable(images)) +# 得分最高的那个类 +_, predicted = t.max(outputs.data, 1) + +print('预测结果: ', ' '.join('%5s'\ + % classes[predicted[j]] for j in range(4))) + + +correct = 0 # 预测正确的图片数 +total = 0 # 总共的图片数 +for data in testloader: + images, labels = data + outputs = net(Variable(images)) + _, predicted = t.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum() + +print('10000张测试集中的准确率为: %d %%' % (100 * correct / total)) \ No newline at end of file diff --git a/demo_code/Nerual_Network.py b/demo_code/Nerual_Network.py new file mode 100644 index 0000000..47d5c99 --- /dev/null +++ b/demo_code/Nerual_Network.py @@ -0,0 +1,96 @@ +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable + +from torchvision import datasets, transforms + +# Training settings +batch_size = 64 + +# MNIST Dataset +dataset_path = "../data/mnist" +train_dataset = datasets.MNIST(root=dataset_path, + train=True, + transform=transforms.ToTensor(), + download=True) + +test_dataset = datasets.MNIST(root=dataset_path, + train=False, + transform=transforms.ToTensor()) + +# Data Loader (Input Pipeline) +train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=batch_size, + shuffle=True) + +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, + batch_size=batch_size, + shuffle=False) + + +# define Network +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.l1 = nn.Linear(784, 520) + self.l2 = nn.Linear(520, 320) + self.l3 = nn.Linear(320, 240) + self.l4 = nn.Linear(240, 120) + self.l5 = nn.Linear(120, 10) + + def forward(self, x): + x = x.view(-1, 784) # Flatten the data (n, 1, 28, 28)-> (n, 784) + x = F.relu(self.l1(x)) + x = F.relu(self.l2(x)) + x = F.relu(self.l3(x)) + x = F.relu(self.l4(x)) + return self.l5(x) + + +model = Net() + +criterion = nn.CrossEntropyLoss() +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + + +def train(epoch): + #model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = Variable(data), Variable(target) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + if batch_idx % 10 == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.data[0])) + + +def test(): + model.eval() + test_loss = 0 + correct = 0 + for data, target in test_loader: + data, target = Variable(data, volatile=True), Variable(target) + output = model(data) + # sum up batch loss + test_loss += criterion(output, target).data[0] + # get the index of the max + pred = output.data.max(1, keepdim=True)[1] + correct += pred.eq(target.data.view_as(pred)).cpu().sum() + + test_loss /= len(test_loader.dataset) + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + + +for epoch in range(1, 10): + train(epoch) + test() diff --git a/demo_code/Neural_Network.py b/demo_code/Neural_Network.0.py similarity index 92% rename from demo_code/Neural_Network.py rename to demo_code/Neural_Network.0.py index fcf95b4..c6a9afd 100644 --- a/demo_code/Neural_Network.py +++ b/demo_code/Neural_Network.0.py @@ -6,9 +6,9 @@ from torch.utils.data import DataLoader from torchvision import transforms from torchvision import datasets -batch_size = 32 -learning_rate = 1e-2 -num_epoches = 50 +batch_size = 32 +learning_rate = 1e-2 +num_epoches = 50 # 下载训练集 MNIST 手写数字训练集 dataset_path = "../data/mnist" @@ -50,15 +50,21 @@ for epoch in range(num_epoches): print('*' * 10) running_loss = 0.0 running_acc = 0.0 + for i, data in enumerate(train_loader, 1): + # FIXME: label need to change one-hot coding img, label = data img = img.view(img.size(0), -1) + target = torch.zeros(label.size(0), 10) + target = target.scatter_(1, label.data, 1) + if torch.cuda.is_available(): img = Variable(img).cuda() label = Variable(label).cuda() else: img = Variable(img) label = Variable(label) + # 向前传播 out = model(img) loss = criterion(out, label) @@ -66,6 +72,7 @@ for epoch in range(num_epoches): _, pred = torch.max(out, 1) num_correct = (pred == label).sum() running_acc += num_correct.data[0] + # 向后传播 optimizer.zero_grad() loss.backward() @@ -75,12 +82,15 @@ for epoch in range(num_epoches): print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format( epoch + 1, num_epoches, running_loss / (batch_size * i), running_acc / (batch_size * i))) + print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format( epoch + 1, running_loss / (len(train_dataset)), running_acc / (len( train_dataset)))) + model.eval() eval_loss = 0. eval_acc = 0. + for data in test_loader: img, label = data img = img.view(img.size(0), -1) @@ -96,9 +106,10 @@ for epoch in range(num_epoches): _, pred = torch.max(out, 1) num_correct = (pred == label).sum() eval_acc += num_correct.data[0] + print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len( test_dataset)), eval_acc / (len(test_dataset)))) print() # 保存模型 -torch.save(model.state_dict(), './neural_network.pth') +torch.save(model.state_dict(), './model_Neural_Network.pth')