pytorch-handbook/chapter5/5.3-Fashion-MNIST.ipynb
2020-03-22 19:43:42 +08:00

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.2.0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch,math\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt \n",
"import torchvision.datasets as dsets\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torch.nn.functional as F\n",
"import torch.nn as NN\n",
"torch.__version__"
]
},
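{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classifying Fashion MNIST\n",
"## About Fashion MNIST\n",
"Fashion MNIST is an entry-level image classification dataset available on Kaggle. It contains 70,000 grayscale images in 10 categories; each image shows a single article of clothing at low resolution (28×28 pixels).\n",
"\n",
"Download and description: [Kaggle page](https://www.kaggle.com/zalando-research/fashionmnist/)\n",
"\n",
"\n",
"Fashion MNIST is intended as a replacement for the classic MNIST dataset, which is often used as the “Hello, World” of machine learning programs for computer vision.\n",
"\n",
"MNIST contains images of handwritten digits (0-9) in the same format as the clothing images used here. Because MNIST only contains the handwritten digits 0-9, its data is not very complex, so it is only suitable as a “Hello, World”.\n",
"\n",
"Fashion MNIST uses images of clothing, which are considerably more complex and more varied than digits, so it is a slightly more challenging problem than regular MNIST.\n",
"\n",
"The dataset is relatively small and is useful for verifying that an algorithm works as expected. It is a good starting point for testing and debugging code.\n",
"\n",
"## Dataset overview\n",
"\n",
"### Classes\n",
"```\n",
"0 T-shirt/top\n",
"1 Trouser\n",
"2 Pullover\n",
"3 Dress\n",
"4 Coat\n",
"5 Sandal\n",
"6 Shirt\n",
"7 Sneaker\n",
"8 Bag\n",
"9 Ankle boot \n",
"```\n",
"### Format\n",
"\n",
"fashion-mnist_test.csv\n",
"\n",
"fashion-mnist_train.csv\n",
"\n",
"These files store the test and training data in the following format:\n",
"\n",
"`label` is the class label.\n",
"`pixel1` to `pixel784` hold the value of each pixel. Because the images are grayscale, each value is an integer between 0 and 255.\n",
"\n",
"Why 784 pixels? Because 28 * 28 = 784.\n",
"\n",
"### Submitting results\n",
"\n",
"Fashion MNIST does not require us to submit predictions. The dataset is already split into a training set and a test set, so we only need to load the data, train, and inspect the results. This makes Fashion MNIST a very good entry-level dataset.\n"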
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#指定数据目录\n",
"DATA_PATH=Path('./data/')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>pixel1</th>\n",
" <th>pixel2</th>\n",
" <th>pixel3</th>\n",
" <th>pixel4</th>\n",
" <th>pixel5</th>\n",
" <th>pixel6</th>\n",
" <th>pixel7</th>\n",
" <th>pixel8</th>\n",
" <th>pixel9</th>\n",
" <th>...</th>\n",
" <th>pixel775</th>\n",
" <th>pixel776</th>\n",
" <th>pixel777</th>\n",
" <th>pixel778</th>\n",
" <th>pixel779</th>\n",
" <th>pixel780</th>\n",
" <th>pixel781</th>\n",
" <th>pixel782</th>\n",
" <th>pixel783</th>\n",
" <th>pixel784</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" <td>43</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>...</td>\n",
" <td>7</td>\n",
" <td>8</td>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>7</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>14</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>8</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>203</td>\n",
" <td>214</td>\n",
" <td>166</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10 rows × 785 columns</p>\n",
"</div>"
],
"text/plain": [
" label pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 \\\n",
"0 2 0 0 0 0 0 0 0 0 \n",
"1 9 0 0 0 0 0 0 0 0 \n",
"2 6 0 0 0 0 0 0 0 5 \n",
"3 0 0 0 0 1 2 0 0 0 \n",
"4 3 0 0 0 0 0 0 0 0 \n",
"5 4 0 0 0 5 4 5 5 3 \n",
"6 4 0 0 0 0 0 0 0 0 \n",
"7 5 0 0 0 0 0 0 0 0 \n",
"8 4 0 0 0 0 0 0 3 2 \n",
"9 8 0 0 0 0 0 0 0 0 \n",
"\n",
" pixel9 ... pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 \\\n",
"0 0 ... 0 0 0 0 0 0 \n",
"1 0 ... 0 0 0 0 0 0 \n",
"2 0 ... 0 0 0 30 43 0 \n",
"3 0 ... 3 0 0 0 0 1 \n",
"4 0 ... 0 0 0 0 0 0 \n",
"5 5 ... 7 8 7 4 3 7 \n",
"6 0 ... 14 0 0 0 0 0 \n",
"7 0 ... 0 0 0 0 0 0 \n",
"8 0 ... 1 0 0 0 0 0 \n",
"9 0 ... 203 214 166 0 0 0 \n",
"\n",
" pixel781 pixel782 pixel783 pixel784 \n",
"0 0 0 0 0 \n",
"1 0 0 0 0 \n",
"2 0 0 0 0 \n",
"3 0 0 0 0 \n",
"4 0 0 0 0 \n",
"5 5 0 0 0 \n",
"6 0 0 0 0 \n",
"7 0 0 0 0 \n",
"8 0 0 0 0 \n",
"9 0 0 0 0 \n",
"\n",
"[10 rows x 785 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train = pd.read_csv(DATA_PATH / \"fashion-mnist_train.csv\");\n",
"train.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>pixel1</th>\n",
" <th>pixel2</th>\n",
" <th>pixel3</th>\n",
" <th>pixel4</th>\n",
" <th>pixel5</th>\n",
" <th>pixel6</th>\n",
" <th>pixel7</th>\n",
" <th>pixel8</th>\n",
" <th>pixel9</th>\n",
" <th>...</th>\n",
" <th>pixel775</th>\n",
" <th>pixel776</th>\n",
" <th>pixel777</th>\n",
" <th>pixel778</th>\n",
" <th>pixel779</th>\n",
" <th>pixel780</th>\n",
" <th>pixel781</th>\n",
" <th>pixel782</th>\n",
" <th>pixel783</th>\n",
" <th>pixel784</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>8</td>\n",
" <td>...</td>\n",
" <td>103</td>\n",
" <td>87</td>\n",
" <td>56</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>14</td>\n",
" <td>53</td>\n",
" <td>99</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>63</td>\n",
" <td>53</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>137</td>\n",
" <td>126</td>\n",
" <td>140</td>\n",
" <td>0</td>\n",
" <td>133</td>\n",
" <td>224</td>\n",
" <td>222</td>\n",
" <td>56</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>44</td>\n",
" <td>105</td>\n",
" <td>44</td>\n",
" <td>10</td>\n",
" <td>...</td>\n",
" <td>105</td>\n",
" <td>64</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>8</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>174</td>\n",
" <td>136</td>\n",
" <td>155</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>57</td>\n",
" <td>70</td>\n",
" <td>28</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10 rows × 785 columns</p>\n",
"</div>"
],
"text/plain": [
" label pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 \\\n",
"0 0 0 0 0 0 0 0 0 9 \n",
"1 1 0 0 0 0 0 0 0 0 \n",
"2 2 0 0 0 0 0 0 14 53 \n",
"3 2 0 0 0 0 0 0 0 0 \n",
"4 3 0 0 0 0 0 0 0 0 \n",
"5 2 0 0 0 0 0 44 105 44 \n",
"6 8 0 0 0 0 0 0 0 0 \n",
"7 6 0 0 0 0 0 0 0 1 \n",
"8 5 0 0 0 0 0 0 0 0 \n",
"9 0 0 0 0 0 0 0 0 0 \n",
"\n",
" pixel9 ... pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 \\\n",
"0 8 ... 103 87 56 0 0 0 \n",
"1 0 ... 34 0 0 0 0 0 \n",
"2 99 ... 0 0 0 0 63 53 \n",
"3 0 ... 137 126 140 0 133 224 \n",
"4 0 ... 0 0 0 0 0 0 \n",
"5 10 ... 105 64 30 0 0 0 \n",
"6 0 ... 0 0 0 0 0 0 \n",
"7 0 ... 174 136 155 31 0 1 \n",
"8 0 ... 0 0 0 0 0 0 \n",
"9 0 ... 57 70 28 0 2 0 \n",
"\n",
" pixel781 pixel782 pixel783 pixel784 \n",
"0 0 0 0 0 \n",
"1 0 0 0 0 \n",
"2 31 0 0 0 \n",
"3 222 56 0 0 \n",
"4 0 0 0 0 \n",
"5 0 0 0 0 \n",
"6 0 0 0 0 \n",
"7 0 0 0 0 \n",
"8 0 0 0 0 \n",
"9 0 0 0 0 \n",
"\n",
"[10 rows x 785 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test = pd.read_csv(DATA_PATH / \"fashion-mnist_test.csv\");\n",
"test.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"label 9\n",
"pixel1 16\n",
"pixel2 36\n",
"pixel3 226\n",
"pixel4 164\n",
" ... \n",
"pixel780 255\n",
"pixel781 255\n",
"pixel782 255\n",
"pixel783 255\n",
"pixel784 170\n",
"Length: 785, dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.max()"
]
},
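{
"cell_type": "markdown",
"metadata": {},
"source": [
"The names of the ubyte files indicate the data format.\n",
"\n",
"The digit after `idx` gives the number of dimensions: `idx3` means the image data is 3-dimensional,\n",
"and `idx1` means the label data is 1-dimensional.\n",
"\n",
"For a detailed description of the format, see http://yann.lecun.com/exdb/mnist/"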
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2051, 60000, 28, 28)\n"
]
}
],
"source": [
"import struct\n",
"from PIL import Image \n",
"\n",
"with open(DATA_PATH / \"train-images-idx3-ubyte\", 'rb') as file_object:\n",
" header_data=struct.unpack(\">4I\",file_object.read(16))\n",
" print(header_data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2049, 60000)\n"
]
}
],
"source": [
"with open(DATA_PATH / \"train-labels-idx1-ubyte\", 'rb') as file_object:\n",
" header_data=struct.unpack(\">2I\",file_object.read(8))\n",
" print(header_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The binary format of the training images is as follows:\n",
"\n",
"    [offset] [type]          [value]          [description]\n",
"    0000     32 bit integer  0x00000803(2051) magic number\n",
"    0004     32 bit integer  60000            number of images\n",
"    0008     32 bit integer  28               number of rows\n",
"    0012     32 bit integer  28               number of columns\n",
"    0016     unsigned byte   ??               pixel\n",
"    0017     unsigned byte   ??               pixel\n",
"    ........\n",
"    xxxx     unsigned byte   ??               pixel\n",
"\n",
"The header consists of four 32-bit integers (16 bytes in total), so when converting the binary data with `unpack_from` the offset is 16."
]
},
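{
"cell_type": "markdown",
"metadata": {},
"source": [
"The header layout described above can be checked on a tiny synthetic buffer. The following is a hedged sketch rather than part of the original notebook: it builds an in-memory IDX3 buffer holding two fake 28×28 images (the pixel values are made up) and reads it back with `struct`. The helper name `read_idx3_image` is hypothetical.\n",
"\n",
"```python\n",
"import struct\n",
"\n",
"def read_idx3_image(raw, index=0):\n",
"    # Header: four big-endian unsigned 32-bit integers = 16 bytes\n",
"    magic, count, rows, cols = struct.unpack_from('>4I', raw, 0)\n",
"    if magic != 2051:\n",
"        raise ValueError('not an IDX3 image file')\n",
"    if not 0 <= index < count:\n",
"        raise IndexError('image index out of range')\n",
"    size = rows * cols\n",
"    # Pixels start right after the 16-byte header, one unsigned byte each\n",
"    offset = 16 + index * size\n",
"    pixels = struct.unpack_from('>%dB' % size, raw, offset)\n",
"    # Split the flat tuple into a list of rows\n",
"    return [list(pixels[r * cols:(r + 1) * cols]) for r in range(rows)]\n",
"\n",
"# Synthetic data (made up): image 0 is all 7s, image 1 is all 9s\n",
"header = struct.pack('>4I', 2051, 2, 28, 28)\n",
"raw = header + bytes([7] * 784) + bytes([9] * 784)\n",
"first = read_idx3_image(raw, 0)\n",
"second = read_idx3_image(raw, 1)\n",
"print(len(first), len(first[0]), first[0][0], second[0][0])\n",
"```\n",
"\n",
"For the real file, `raw` would be the bytes read from `train-images-idx3-ubyte`; the offset of image `i` is then `16 + i * 784`."
]
},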
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(28, 28)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEfJJREFUeJzt3W2M1eWZx/HfJfjEg6AigsiKVlzZGBfXEY1PUStGN41atVhfbDDW0piabJOarPFNTcxGott2+8I0odZUY2vbpFI1PtWYTdwNqIyEAHW2LSrWERxUFHl0GLj2BYfNiPO/rsM5Z8459P5+EjMz55p7zj1n+HnOzPW/79vcXQDKc1inJwCgMwg/UCjCDxSK8AOFIvxAoQg/UCjCDxSK8AOFIvxAoca2887MjMsJgVHm7lbP5zX1zG9mV5vZn8xsnZnd3czXAtBe1ui1/WY2RtKfJc2X1C9phaRb3P3NYAzP/MAoa8cz/zxJ69z9bXcflPRrSdc18fUAtFEz4Z8h6b1hH/fXbvsCM1tkZr1m1tvEfQFosWb+4DfSS4svvax39yWSlki87Ae6STPP/P2SZg77+GRJG5qbDoB2aSb8KyTNNrNTzewISd+U9HRrpgVgtDX8st/dh8zsTkkvShoj6RF3/2PLZgZgVDXc6mvozvidHxh1bbnIB8Chi/ADhSL8QKEIP1Aowg8UivADhSL8QKEIP1Aowg8UivADhSL8QKEIP1Aowg8Uqq1bd6P9zOIFXs2u6pw4cWJYv/jiiytrzz//fFP3nX1vY8aMqawNDQ01dd/NyuYeadVKXJ75gUIRfqBQhB8oFOEHCkX4gUIRfqBQhB8oFH3+v3GHHRb//33Pnj1h/fTTTw/rt99+e1jfuXNnZW379u3h2F27doX1119/Paw308vP+vDZ45qNb2Zu0fUL2c9zOJ75gUIRfqBQhB8oFOEHCkX4gUIRfqBQhB8oVFN9fjNbL2mrpD2Shty9pxWTQutEPWEp7wtfccUVYf3KK68M6/39/ZW1I488Mhw7bty4sD5//vyw/vDDD1fWBgYGwrHZmvmD6aePZMKECZW1vXv3hmN37NjR1H3v14qLfC53949a8HUAtBEv+4FCNRt+l/QHM3vDzBa1YkIA2qPZl/0XufsGM5sq6SUz+193f2X4J9T+p8D/GIAu09Qzv7tvqL3dJGmppHkjfM4Sd+/hj4FAd2k4/GY23swm7n9f0lWS1rZqYgBGVzMv+0+UtLS2dHGspF+5+wstmRWAUddw+N39bUn/2MK5YBQMDg42Nf68884L67NmzQrr0XUG2Zr4F198Mayfc845Yf2BBx6orPX29oZj16xZE9b7+vrC+rx5X/oN+Auix3XZsmXh2OXLl1fWtm3bFo4djlYfUCjCDxSK8AOFIvxAoQg/UCjCDxTKWnXcb113Zta+OytItE109vPNlsVG7TJJmjx5cljfvXt3ZS1buppZsWJFWF+3bl1lrdkW6PTp08N69H1L8dxvuummcOxDDz1UWevt7dVnn31W1/nfPPMDhSL8QKEIP1Aowg8UivADhSL8QKEIP1Ao+vxdIDvOuRnZz/fVV18N69mS3Uz0vWXHVDfbi4+O+M6uMVi5cmVYj64hkPLv7eqrr66snXbaaeHYGTNmhHV3p88PoBrhBwpF+IFCEX6gUIQfKBThBwpF+IFCteKUXjSpnddaHOiTTz4J69m69Z07d4b16BjusWPjf37RMdZS3MeXpKOPPrqylvX5L7nkkrB+4YUXhvVsW/KpU6dW1l54oT3HX/DMDxSK8AOFIvxAoQg/UCjCDxSK8AOFIvxAodI+v5k9Iulrkja5+1m1246T9BtJsyStl7TA3eOGMbrSuHHjwnrWr87qO3bsqKxt2bIlHPvxxx+H9Wyvgej6iWwPhez7yh63PXv2hPXoOoOZM2eGY1ulnmf+X0g6cOeBuyW97O6zJb1c+xjAIS
QNv7u/ImnzATdfJ+nR2vuPSrq+xfMCMMoa/Z3/RHffKEm1t9XXKgLoSqN+bb+ZLZK0aLTvB8DBafSZf8DMpktS7e2mqk909yXu3uPuPQ3eF4BR0Gj4n5a0sPb+QklPtWY6ANolDb+ZPSFpuaS/N7N+M/uWpMWS5pvZXyTNr30M4BCS/s7v7rdUlL7a4rkUq9mec9RTztbEn3TSSWH9888/b6oerefP9uWPrhGQpMmTJ4f16DqBrE9/xBFHhPWtW7eG9UmTJoX11atXV9ayn1lPT/Vv0G+++WY4djiu8AMKRfiBQhF+oFCEHygU4QcKRfiBQrF1dxfItu4eM2ZMWI9afTfffHM4dtq0aWH9ww8/DOvR9thSvHR1/Pjx4dhsaWvWKozajLt37w7HZtuKZ9/38ccfH9YfeuihytrcuXPDsdHcDua4d575gUIRfqBQhB8oFOEHCkX4gUIRfqBQhB8olLXzeGgz69xZ1F0s6ykPDQ01/LXPP//8sP7ss8+G9ewI7mauQZg4cWI4NjuCO9va+/DDD2+oJuXXIGRHm2ei7+3BBx8Mxz7++ONh3d3ravbzzA8UivADhSL8QKEIP1Aowg8UivADhSL8QKEOqfX80VrlrN+cbX+drYOO1n9Ha9br0UwfP/Pcc8+F9e3bt4f1rM+fbXEdXUeS7RWQ/UyPOuqosJ6t2W9mbPYzz+Z+9tlnV9ayo8tbhWd+oFCEHygU4QcKRfiBQhF+oFCEHygU4QcKlfb5zewRSV+TtMndz6rddq+kb0va36i9x93jhnIdmlkbPpq98tF26aWXhvUbb7wxrF900UWVteyY62xNfNbHz/YiiH5m2dyyfw/RvvxSfB1Ato9FNrdM9rht27atsnbDDTeEY5955pmG5nSgep75fyHp6hFu/7G7z63913TwAbRXGn53f0XS5jbMBUAbNfM7/51mttrMHjGzY1s2IwBt0Wj4fyrpK5LmStoo6YdVn2hmi8ys18x6G7wvAKOgofC7+4C773H3vZJ+Jmle8LlL3L3H3XsanSSA1mso/GY2fdiHX5e0tjXTAdAu9bT6npB0maQpZtYv6QeSLjOzuZJc0npJ3xnFOQIYBcXs23/ccceF9ZNOOimsz549u+GxWd/2jDPOCOuff/55WI/2KsjWpWfnzG/YsCGsZ/vfR/3u7Az7wcHBsD5u3LiwvmzZssrahAkTwrHZtRfZev5sTX70uA0MDIRj58yZE9bZtx9AiPADhSL8QKEIP1Aowg8UivADheqqVt8FF1wQjr/vvvsqayeccEI4dvLkyWE9WnoqxctLP/3003Bsttw4a1llLa9o2/Fs6+2+vr6wvmDBgrDe2xtftR0dw33ssfGSkFmzZoX1zNtvv11Zy44H37p1a1jPlvxmLdSo1XjMMceEY7N/L7T6AIQIP1Aowg8UivADhSL8QKEIP1Aowg8Uqu19/qhfvnz58nD89OnTK2tZnz6rN7NVc7bFdNZrb9akSZMqa1OmTAnH3nrrrWH9qquuCut33HFHWI+WBO/atSsc+84774T1qI8vxcuwm11OnC1lzq4jiMZny4VPOeWUsE6fH0CI8AOFIvxAoQg/UCjCDxSK8AOFIvxAodra558yZYpfe+21lfXFixeH4996663KWrYVc1bPjnuOZD3fqA8vSe+9915Yz7bPjvYyiLb1lqRp06aF9euvvz6sR8dgS/Ga/Oxncu655zZVj773rI+fPW7ZEdyZaA+G7N9TtO/FBx98oMHBQfr8AKoRfqBQhB8oFOEHCkX4gUIRfqBQhB8o1NjsE8xspqTHJE2TtFfSEnf/iZkdJ+k3kmZJWi9pgbt/En2toaEhbdq0qbKe9bujNdLZMdbZ1856zlFfN9tnffPmzWH93XffDevZ3KL9ArI189mZAkuXLg3ra9asCetRnz87Nj3rxWfnJUTHk2ffd7amPuvFZ+OjPn92DUF0pHv2mAxXzzP/kKTvu/scSRdI+q6Z/YOkuyW97O6zJb
1c+xjAISINv7tvdPeVtfe3SuqTNEPSdZIerX3ao5LiS8EAdJWD+p3fzGZJOkfSa5JOdPeN0r7/QUia2urJARg9dYffzCZI+p2k77n7ZwcxbpGZ9ZpZb/Y7HID2qSv8Zna49gX/l+7+ZO3mATObXqtPlzTiX/LcfYm797h7T7OLIQC0Thp+2/dnyZ9L6nP3Hw0rPS1pYe39hZKeav30AIyWtNUn6SJJ/yJpjZmtqt12j6TFkn5rZt+S9FdJ38i+0ODgoN5///3Kera8uL+/v7I2fvz4cGy2hXXWIvnoo48qax9++GE4duzY+GHOlhNnbaVoWW22hXS2dDX6viVpzpw5YX379u2Vtaz9+sknYec4fdyiuUdtQClvBWbjsyO6o6XUW7ZsCcfOnTu3srZ27dpw7HBp+N39fyRVNSW/Wvc9AegqXOEHFIrwA4Ui/EChCD9QKMIPFIrwA4Wqp8/fMjt37tSqVasq608++WRlTZJuu+22ylq2vXV2nHO29DVaVpv14bOeb3blY3YEeLScOTuaPLu2Iju6fOPGjQ1//Wxu2fURzfzMml0u3MxyYim+juDUU08Nxw4MDDR8v8PxzA8UivADhSL8QKEIP1Aowg8UivADhSL8QKHaekS3mTV1Z9dcc01l7a677grHTp0abzGYrVuP+rpZvzrr02d9/qzfHX39aItoKe/zZ9cwZPXoe8vGZnPPROOjXnk9sp9ZtnV3tJ5/9erV4dgFCxaEdXfniG4A1Qg/UCjCDxSK8AOFIvxAoQg/UCjCDxSq7X3+aJ/4rDfajMsvvzys33///WE9uk5g0qRJ4dhsb/zsOoCsz59dZxCJjkyX8usAonMYpPhnum3btnBs9rhkorln696zfQyyn+lLL70U1vv6+ipry5YtC8dm6PMDCBF+oFCEHygU4QcKRfiBQhF+oFCEHyhU2uc3s5mSHpM0TdJeSUvc/Sdmdq+kb0vafzj9Pe7+XPK12ndRQRudeeaZYX3KlClhPdsD/uSTTw7r69evr6xl/ey33norrOPQU2+fv55DO4Ykfd/dV5rZRElvmNn+Kxh+7O7/0egkAXROGn533yhpY+39rWbWJ2nGaE8MwOg6qN/5zWyWpHMkvVa76U4zW21mj5jZsRVjFplZr5n1NjVTAC1Vd/jNbIKk30n6nrt/Jumnkr4iaa72vTL44Ujj3H2Ju/e4e08L5gugReoKv5kdrn3B/6W7PylJ7j7g7nvcfa+kn0maN3rTBNBqafht3xaoP5fU5+4/Gnb79GGf9nVJa1s/PQCjpZ5W38WS/lvSGu1r9UnSPZJu0b6X/C5pvaTv1P44GH2tv8lWH9BN6m31HVL79gPIsZ4fQIjwA4Ui/EChCD9QKMIPFIrwA4Ui/EChCD9QKMIPFIrwA4Ui/EChCD9QKMIPFIrwA4WqZ/feVvpI0rvDPp5Su60bdevcunVeEnNrVCvndkq9n9jW9fxfunOz3m7d269b59at85KYW6M6NTde9gOFIvxAoTod/iUdvv9It86tW+clMbdGdWRuHf2dH0DndPqZH0CHdCT8Zna1mf3JzNaZ2d2dmEMVM1tvZmvMbFWnjxirHYO2yczWDrvtODN7ycz+Uns74jFpHZrbvWb2fu2xW2Vm/9yhuc00s/8ysz4z+6OZ/Wvt9o4+dsG8OvK4tf1lv5mNkfRnSfMl9UtaIekWd3+zrROpYGbrJfW4e8d7wmZ2qaRtkh5z97Nqtz0gabO7L679j/NYd/+3LpnbvZK2dfrk5tqBMtOHnywt6XpJt6qDj10wrwXqwOPWiWf+eZLWufvb7j4o6deSruvAPLqeu78iafMBN18n6dHa+49q3z+etquYW1dw943uvrL2/lZJ+0+W7uhjF8yrIzoR/hmS3hv2cb+668hvl/QHM3vDzBZ1ejIjOHH/yUi1t1M7PJ8DpSc3t9MBJ0t3zWPXyInXrdaJ8I90mkg3tRwucvd/knSNpO/WXt6iPnWd3NwuI5ws3RUaPf
G61ToR/n5JM4d9fLKkDR2Yx4jcfUPt7SZJS9V9pw8P7D8ktfZ2U4fn8/+66eTmkU6WVhc8dt104nUnwr9C0mwzO9XMjpD0TUlPd2AeX2Jm42t/iJGZjZd0lbrv9OGnJS2svb9Q0lMdnMsXdMvJzVUnS6vDj123nXjdkYt8aq2M/5Q0RtIj7v7vbZ/ECMzsNO17tpf2rXj8VSfnZmZPSLpM+1Z9DUj6gaTfS/qtpL+T9FdJ33D3tv/hrWJul+kgT24epblVnSz9mjr42LXyxOuWzIcr/IAycYUfUCjCDxSK8AOFIvxAoQg/UCjCDxSK8AOFIvxAof4PYwQAhKEd7F8AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"with open(DATA_PATH / \"train-images-idx3-ubyte\", 'rb') as file_object:\n",
" raw_img=file_object.read()\n",
"img = struct.unpack_from(\">784B\",raw_img,16)\n",
"image = np.asarray(img)\n",
"image = image.reshape((28,28))\n",
"print(image.shape)\n",
"plt.imshow(image,cmap = plt.cm.gray)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0,)\n"
]
}
],
"source": [
"with open(DATA_PATH / \"train-labels-idx1-ubyte\", 'rb') as file_object:\n",
" raw_img = file_object.read(1)\n",
" label = struct.unpack(\">B\",raw_img)\n",
" print(label)"
]
},
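{
"cell_type": "markdown",
"metadata": {},
"source": [
"Something looks off here: the display appears misaligned, even though the data was processed according to the format. Note also that the label cell above reads the very first byte of the file, which belongs to the 8-byte header rather than to the labels, so the printed `(0,)` is not the first label. This format is relatively cumbersome to handle, and the CSV files in the dataset give the value of every pixel directly, so from here on we simply use the CSV data."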
]
},
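{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged sketch of reading IDX1 labels with the header skipped. It uses a synthetic in-memory label file (the label values are made up, not taken from the dataset) and compares the first byte of the file with the byte at offset 8, where the labels start according to the format description linked above.\n",
"\n",
"```python\n",
"import struct\n",
"\n",
"# Synthetic IDX1 label file: 8-byte header followed by one byte per label\n",
"labels = [9, 0, 3]  # made-up labels\n",
"raw = struct.pack('>2I', 2049, len(labels)) + bytes(labels)\n",
"\n",
"# The very first byte is part of the magic number 0x00000801\n",
"first_byte = struct.unpack_from('>B', raw, 0)[0]\n",
"\n",
"# Skipping the 8-byte header gives the first real label\n",
"magic, count = struct.unpack_from('>2I', raw, 0)\n",
"first_label = struct.unpack_from('>B', raw, 8)[0]\n",
"all_labels = list(struct.unpack_from('>%dB' % count, raw, 8))\n",
"print(first_byte, first_label, all_labels)\n",
"```\n",
"\n",
"With a file object, the same effect can be obtained by calling `file_object.seek(8)` before reading the label bytes."
]
},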
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the data\n",
"\n",
"To load the data with PyTorch's DataLoader, we first need to create a custom Dataset."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class FashionMNISTDataset(Dataset):\n",
"    def __init__(self, csv_file, transform=None):\n",
"        data = pd.read_csv(csv_file)\n",
"        # Column 0 is the label; the remaining 784 columns are the pixels\n",
"        self.X = np.array(data.iloc[:, 1:]).reshape(-1, 1, 28, 28).astype(float)\n",
"        self.Y = np.array(data.iloc[:, 0])\n",
"        del data  # drop the reference to the DataFrame to save memory\n",
"        self.len = len(self.X)\n",
"\n",
"    def __len__(self):\n",
"        # return len(self.X)\n",
"        return self.len\n",
"\n",
"    def __getitem__(self, idx):\n",
"        item = self.X[idx]\n",
"        label = self.Y[idx]\n",
"        return (item, label)"
]
},
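{
"cell_type": "markdown",
"metadata": {},
"source": [
"A custom dataset only needs to implement three methods:\n",
"\n",
"`__init__`: the initializer, mainly used to load the data. Here we read the data into a DataFrame with pandas and then convert it to NumPy arrays for indexing.\n",
"\n",
"`__len__`: returns the total number of samples. PyTorch's DataLoader needs to know the size of the dataset.\n",
"\n",
"`__getitem__`: takes an index and returns a single image together with its label.\n",
"\n",
"Create the training and test sets:"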
]
},
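{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three methods can be illustrated without any dependencies. The sketch below is a plain-Python stand-in, not the notebook's class: `TinyCsvDataset` is a hypothetical name, it does not subclass `torch.utils.data.Dataset`, and the rows are made-up 2×2 images that only mimic the `label, pixel1, ...` column layout of the CSV files.\n",
"\n",
"```python\n",
"import csv\n",
"\n",
"class TinyCsvDataset:\n",
"    # Stand-in for a map-style dataset: __init__, __len__, __getitem__\n",
"    def __init__(self, csv_lines, side):\n",
"        reader = csv.reader(csv_lines)\n",
"        next(reader)  # skip the header row: label,pixel1,...\n",
"        self.X, self.Y = [], []\n",
"        for row in reader:\n",
"            values = [int(v) for v in row]\n",
"            self.Y.append(values[0])\n",
"            pixels = [float(v) for v in values[1:]]\n",
"            # Reshape the flat pixel row into side x side\n",
"            self.X.append([pixels[r * side:(r + 1) * side] for r in range(side)])\n",
"\n",
"    def __len__(self):\n",
"        return len(self.X)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.X[idx], self.Y[idx]\n",
"\n",
"# Made-up rows in the same column layout as the Fashion MNIST CSV\n",
"rows = ['label,pixel1,pixel2,pixel3,pixel4', '2,0,255,0,255', '9,10,20,30,40']\n",
"ds = TinyCsvDataset(rows, side=2)\n",
"item, label = ds[1]\n",
"print(len(ds), label, item)\n",
"```\n",
"\n",
"The notebook's `FashionMNISTDataset` follows the same pattern, with pandas and NumPy doing the parsing and the reshape to `(-1, 1, 28, 28)`."
]
},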
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = FashionMNISTDataset(csv_file=DATA_PATH / \"fashion-mnist_train.csv\")\n",
"test_dataset = FashionMNISTDataset(csv_file=DATA_PATH / \"fashion-mnist_test.csv\")"
]
},
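{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before reading data with PyTorch's DataLoader, we need to choose a batch size. This is also a hyperparameter, and it affects memory usage: if an out-of-memory (OOM) error occurs, reduce this value. It is usually a power of 2 or a multiple of 2."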
]
},
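{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what a batch size means in practice, the sketch below computes how many batches one pass over the data produces and how large the last batch is. It assumes the final, smaller batch is kept (the DataLoader default, `drop_last=False`), 60,000 training samples as reported by the file header earlier, and 10,000 test samples as implied by the 70,000 total. The helper name `batch_plan` is hypothetical.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def batch_plan(num_samples, batch_size):\n",
"    # Number of batches when the last, possibly smaller batch is kept\n",
"    num_batches = math.ceil(num_samples / batch_size)\n",
"    last_batch = num_samples - (num_batches - 1) * batch_size\n",
"    return num_batches, last_batch\n",
"\n",
"print(batch_plan(60000, 256))  # training set\n",
"print(batch_plan(10000, 256))  # test set\n",
"```\n",
"\n",
"With a batch size of 256 this gives 235 training batches (the last one with 96 samples) and 40 test batches (the last one with 16 samples)."
]
},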
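{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Constants are written in upper case; prefer full English words for constant names to reduce ambiguity\n",
"BATCH_SIZE=256 # this batch size can be trained on an M250 laptop GPU without running out of memory"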
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we use the DataLoader to feed this data."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset,\n",
"                                           batch_size=BATCH_SIZE,\n",
"                                           shuffle=True) # shuffle=True shuffles the sample order"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n",
"                                          batch_size=BATCH_SIZE,\n",
"                                          shuffle=False) # the test set does not need to be shuffled"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Take a look at the data:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 28, 28]), torch.Size([28, 28]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a=iter(train_loader)\n",
"data=next(a)\n",
"img=data[0][0].reshape(28,28)\n",
"data[0][0].shape,img.shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAR6klEQVR4nO3dbYxUVZ4G8OehedNu0Oalu3lvQKIubVYMggloUGRkMQQnihmiG9YYMXGMQzIflrgf4IMhZrMzk00kE3uCGWaddUKcAYkiO0hQ10QILUHAQUYgLNM00CDIiyCv//3Q10mLff+nrVtVt+A8v6RT3fXvU/d00Q+3qs8959DMICLXvx55d0BEykNhF4mEwi4SCYVdJBIKu0gkepbzYCT1p3+REjMzdnV/pjM7yZkkd5PcQ3JRlscSkdJioePsJKsA/BXADACtALYAmGdmf3Ha6MwuUmKlOLNPArDHzPaZ2QUAfwAwJ8PjiUgJZQn7MAB/6/R1a3Lfd5BcQLKFZEuGY4lIRln+QNfVS4XvvUw3s2YAzYBexovkKcuZvRXAiE5fDwfQlq07IlIqWcK+BcA4kqNJ9gbwEwBritMtESm2gl/Gm9klks8D+B8AVQBeM7PPitYzESmqgofeCjqY3rOLlFxJLqoRkWuHwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SCYVdJBIKu0gkFHaRSJR1KWnpGtnlJKW/K+XMxDyPnVWfPn3c+oULF1JrWX+usWPHuvX77rvPrb/++uuptYsXL7ptvX8z7+fSmV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYRWl5VMQuP0PXoUfj65fPlywW1DGhsb3XponDykvr7erZ86dSq19uqrr2Y6tlaXFYmcwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiofnsFSA0Vh2S55zz0LFLOVbe1NTk1h966KHUWv/+/d22e/fudev9+vVz6zfeeKNbP3PmjFsvhUxhJ7kfwGkAlwFcMrOJxeiUiBRfMc7s95vZsSI8joiUkN6zi0Qia9gNwJ9JfkJyQVffQHIByRaSLRmPJSIZZH0ZP8XM2kjWAVhP8nMz+7DzN5hZM4BmQBNhRPKU6cxuZm3JbTuAVQAmFaNTIlJ8BYedZDXJft9+DuBHAHYWq2MiUlxZXsbXA1iVjBH3BPDfZrauKL2KTJ7rwr/wwgtuPTTevHHjRrfuzRsfP36823b06NFuvaGhwa23tbWl1g4ePOi2Dc1HHzp0qFtfvny5W9+5s/znxYLDbmb7APxjEfsiIiWkoTeRSCjsIpFQ2EUiobCLREJhF4mElpIug1Jvizx48ODU2uTJk922gwYNcuuhvoeGoOrq6lJrffv2ddu2tra6dW85ZsCfxjpq1Ci37bp1/ijym2++6dazyPr7oqWkRSKnsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIaCnpRJ8+fdz6pUuXUmuh5ZJD46Kh6ZSzZs1y68OHD0+thaZSnjx50q1XV1e79ZDDhw+n1q5cueK2vfnmm916aDlmb4rsypUr3bahcfaQqqoqt+797KW69kVndpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEprPXgShMdXQksgTJ/qb3/bs6V8O4W0PvGHDBrftyJEj3fo999zj1kNLMg8cODC1Fvq5Tpw44dZvv/12t/7BBx+k1tasWeO2LfUaBFm26dZ8dhFxKewikVDYRSKhsItEQmEXiYTCLhIJhV0kEprP3k3enPO7777bbRva1jg0L/ubb75x697xly1b5rZ955133PrZs2fdemhOutf3UNvQ89Kjh3+u2rJli1vPolevXm499LOF1kDweGP03hh88MxO8jWS7SR3drpvAMn1JL9Ibmt/aIdFpLy68zL+twBmXnXfIgAbzGwcgA3J1yJSwYJhN7MPARy/6u45AFYkn6
8A8EiR+yUiRVboe/Z6MzsEAGZ2iGTqhl4kFwBYUOBxRKRISv4HOjNrBtAMXL8TYUSuBYUOvR0hOQQAktv24nVJREqh0LCvATA/+Xw+gLeK0x0RKZXgy3iSbwCYBmAQyVYAiwG8DGAlyacBHAAwt5SdLIfbbrvNrXtrs7/99tuZjj1u3Di3vnbtWrd+xx13pNY+/vhjt+2jjz7q1pcuXerW58yZ49a3b9+eWvv666/dtr1793br3jx+IDwO7wnNGb948WLBjx0SWt9g8uTJqTVvPfxg2M1sXkppeqitiFQOXS4rEgmFXSQSCrtIJBR2kUgo7CKRuKamuBY6tQ8Aamv9iXmhobfVq1e7dY+3nDIQnsLqbXsM+Fsbh4b1Fi3y5zBt2rTJrR84cMCtz549O7V2/vx5t+3x41dPyfhhxx41alRqLbQEdkho+fARI0a49Ycffji1dsstt7htjx07VlC/dGYXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSJxTY2ze1MWQ0vzetNAgfBU0CxCY/yhqZ79+/d363379k2tXbp0yW0bGkcP+fzzz916U1NTai20ZXN1dbVbb2hocOuhKbKemTOvXmP1u2699Va33tjY6Na93+XQ78unn36aWvOm3urMLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItE4poaZw9tg+tpa2tz6xMmTHDr69atS62F5i7X1aXujgUge9+GDh2aWluyZInbttS85Z5D4+DefHQgvA7AkCFDUmvPPvus23bKlCmZjn3ixAm37v3socd+//33U2unT59OrenMLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItE4poaZw+tDe/Zs2dPpmPff//9qTVvPjkAvPvuu249tGa9Nycc8Ofy19TUuG2zmjp1qlv35mafO3fObXvmzBm3vm/fPrfuXX8wb17a5sQdNm/e7NZDa/kPGDDArXu/M6G1GbwxfK9t8MxO8jWS7SR3drpvCcmDJLclH7NCjyMi+erOy/jfAuhq2Y5fmdmdycfa4nZLRIotGHYz+xCAvw+PiFS8LH+ge57k9uRlfuobM5ILSLaQbMlwLBHJqNCw/xrAWAB3AjgE4Bdp32hmzWY20cwmFngsESmCgsJuZkfM7LKZXQHwGwCTitstESm2gsJOsvPcwR8D2Jn2vSJSGYLj7CTfADANwCCSrQAWA5hG8k4ABmA/AH9y8DUgNA6/d+/e1FqW8X8gPBburTEO+OuIb9u2zW3r7Z8OhNfTnzt3rlvfsWNHaq1fv35u29Ca9944OgDMmDEjtbZw4UK37UcffeTWQ7zrMgDgqaeeSq3t3r3bbXvy5MmC+hQMu5l1dfXB8oKOJiK50eWyIpFQ2EUiobCLREJhF4mEwi4SibJPcSWZWss6hJWF1y8gW98GDx7s1kNDa8OGDXPr3vDWSy+95LZ9/PHH3fr06dPdemgZbG+r7NbWVrdtaEvmBx980K1PmzYttRZa6jmrrVu3uvVnnnkmtRZaSrpQOrOLREJhF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpEo+zh7nmPpnlL2KzQV88svv3Troe2Dx48fn1p75ZVX3LaTJ09266ElkUP19vb21Jq3nTMA3HvvvW79gQcecOveWHrouoqQrL8v3lh6aAntQunMLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItE4prasrlS58J749wAMGbMGLd+9OhRt15fX19w+9B20MOHD3froTHfr776yq17y0E/9thjbttQPXR9Qp8+fVJr58+fd9tmHYc/e/asW7948WJq7eDBg5mOnUZndpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEtfUOHspx9Jra2vdujcnPbT18Llz59z6sWPH3PqyZcvcujcW3qtXL7dtaJw8NOc8tMb5k08+mVpbvHix29bbJhsI/2zeWHrWcfSsvGsAQltVFyp4Zic5guRGkrtIfkbyZ8n9A0iuJ/
lFcuunRURy1Z2X8ZcA/NzMbgdwD4CfkvwHAIsAbDCzcQA2JF+LSIUKht3MDpnZ1uTz0wB2ARgGYA6AFcm3rQDwSKk6KSLZ/aD37CQbAUwAsBlAvZkdAjr+QyBZl9JmAYAF2bopIll1O+wkawD8EcBCMzvV3T9wmFkzgObkMSpztUmRCHRr6I1kL3QE/fdm9qfk7iMkhyT1IQDSlxEVkdwFz+zsOIUvB7DLzH7ZqbQGwHwALye3b3XngFmGPLyht6qqKrdtU1OTW+/Z038qvH6Hhp9C00RD00xPnjzp1r1hnIEDB7ptQz/3hQsX3Hqo7++9915qbfXq1W7bEG+aaEhoGDfr0Fyob97zevjw4UzHTtOdl/FTAPwzgB0ktyX3vYiOkK8k+TSAAwDmlqSHIlIUwbCb2UcA0v6bm17c7ohIqehyWZFIKOwikVDYRSKhsItEQmEXicR1s2VzaOvg0DTU0Fh47969U2s1NTVu21C9b9++bj00Fu5NQ62urnbbXrlyJdOxhw0b5tZDy0F7QmPdlbr9N5BtnD40JbpQOrOLREJhF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpGoqKWks4yrjhw50m3bo4f//1poLNwbbw71OzRWfcMNN7j1EO8agNBc+5C5c/2Zy88991zBj30tj6Nn5f2bHT9+vCTH1JldJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhF4lEWcfZa2pqcNddd6XWn3jiCbe9N2YcmhN+9uxZt+6NewLAqVOnUmuhbY9DQtcAhDQ0NKTWbrrpJrftmDFj3PrSpUvdektLi1vPe2vkQmUd4w+1b2xsTK1l/X1IfdySPKqIVByFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0SC3dinegSA3wFoAHAFQLOZ/SfJJQCeAXA0+dYXzWxt4LHcg9XW1rp98cbSx48f77YdNGiQWx81apRbr6urS62FxrL79+/v1kNzzkNj1e3t7am13bt3u21XrVrl1o8ePerWr1dVVVVu/fLly249y7/Z7Nmz3babNm1y62bW5cG7c1HNJQA/N7OtJPsB+ITk+qT2KzP7j248hojkrDv7sx8CcCj5/DTJXQD8bUBEpOL8oPfsJBsBTACwObnreZLbSb5GssvX4CQXkGwh6V9XKSIl1e2wk6wB8EcAC83sFIBfAxgL4E50nPl/0VU7M2s2s4lmNrEI/RWRAnUr7CR7oSPovzezPwGAmR0xs8tmdgXAbwBMKl03RSSrYNjZ8WfF5QB2mdkvO90/pNO3/RjAzuJ3T0SKpTtDb1MB/C+AHegYegOAFwHMQ8dLeAOwH8CzyR/zvMe6ftcGFqkQaUNvwbAXk8IuUnppYdcVdCKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSZd2yGcAxAP/X6etByX2VqFL7Vqn9AtS3QhWzb6lropd1Pvv3Dk62VOradJXat0rtF6C+FapcfdPLeJFIKOwikcg77M05H99TqX2r1H4B6luhytK3XN+zi0j55H1mF5EyUdhFIpFL2EnOJLmb5B6Si/LoQxqS+0nuILkt7/3pkj302knu7HTfAJLrSX6R3Pr7XJe3b0tIHkyeu20kZ+XUtxEkN5LcRfIzkj9L7s/1uXP6VZbnrezv2UlWAfgrgBkAWgFsATDPzP5S1o6kILkfwEQzy/0CDJL3ATgD4Hdm1pTc9+8AjpvZy8l/lLVm9q8V0rclAM7kvY13slvRkM7bjAN4BMC/IMfnzunX4yjD85bHmX0SgD1mts/MLgD4A4A5OfSj4pnZhwCOX3X3HAArks9XoOOXpexS+lYRzOyQmW1NPj8N4NttxnN97px+lUUeYR8G4G+dvm5FZe33bgD+TPITkgvy7kwX6r/dZiu5rcu5P1cLbuNdTldtM14xz10h259nlUfYu9qappLG/6aY2V0A/gnAT5OXq9I93drGu1y62Ga8Ih
S6/XlWeYS9FcCITl8PB9CWQz+6ZGZtyW07gFWovK2oj3y7g25y255zf/6ukrbx7mqbcVTAc5fn9ud5hH0LgHEkR5PsDeAnANbk0I/vIVmd/OEEJKsB/AiVtxX1GgDzk8/nA3grx758R6Vs4522zThyfu5y3/7czMr+AWAWOv4ivxfAv+XRh5R+jQHwafLxWd59A/AGOl7WXUTHK6KnAQwEsAHAF8ntgArq23+hY2vv7egI1pCc+jYVHW8NtwPYlnzMyvu5c/pVludNl8uKREJX0IlEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikfh/PMG3KVbuIeoAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(img,cmap = plt.cm.gray)\n",
"plt.show()"
]
},
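{
"cell_type": "markdown",
"metadata": {},
"source": [
"This time the image looks right: it is a complete picture, so we will stick with the CSV data.\n",
"\n",
"## Building the network\n",
"\n",
"A simple CNN with three convolutional layers."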
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class CNN(NN.Module):\n",
"    def __init__(self):\n",
"        super(CNN, self).__init__()\n",
"        self.layer1 = NN.Sequential(\n",
"            NN.Conv2d(1, 16, kernel_size=5, padding=2),\n",
"            NN.BatchNorm2d(16),\n",
"            NN.ReLU())  # output: 16 x 28 x 28\n",
"        self.pool1 = NN.MaxPool2d(2)  # output: 16 x 14 x 14\n",
"        self.layer2 = NN.Sequential(\n",
"            NN.Conv2d(16, 32, kernel_size=3),\n",
"            NN.BatchNorm2d(32),\n",
"            NN.ReLU())  # output: 32 x 12 x 12\n",
"        self.layer3 = NN.Sequential(\n",
"            NN.Conv2d(32, 64, kernel_size=3),\n",
"            NN.BatchNorm2d(64),\n",
"            NN.ReLU())  # output: 64 x 10 x 10\n",
"        self.pool2 = NN.MaxPool2d(2)  # output: 64 x 5 x 5\n",
"        self.fc = NN.Linear(5*5*64, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # Add print(out.shape) after any step to inspect the intermediate shapes\n",
"        out = self.layer1(x)\n",
"        out = self.pool1(out)\n",
"        out = self.layer2(out)\n",
"        out = self.layer3(out)\n",
"        out = self.pool2(out)\n",
"        # Flatten to (batch, 5*5*64) for the fully connected layer\n",
"        out = out.view(out.size(0), -1)\n",
"        out = self.fc(out)\n",
"        return out"
]
},
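{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code above looks simple, but there is mathematics behind every piece of it. We only cover the PyTorch side here: the layers are defined in `__init__` with modules provided by `torch.nn`. Each convolution is followed by batch normalization and a ReLU activation. Max pooling after the first and third convolutional blocks halves the spatial size, which reduces the amount of data to process and helps limit overfitting. Finally, a fully connected layer outputs the 10 classes.\n",
"\n",
"The `view` call reshapes the output tensor so that it matches the input dimension of the last layer."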
]
},
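{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shape comments in the `CNN` class can be checked by hand. The following is a worked sketch, assuming the usual output-size formula for a convolution or pooling layer; the symbols $n$ (input size), $k$ (kernel size), $p$ (padding) and $s$ (stride) are introduced here for illustration and do not appear in the code:\n",
"\n",
"$$n_{\\text{out}} = \\left\\lfloor \\frac{n + 2p - k}{s} \\right\\rfloor + 1$$\n",
"\n",
"Applied to a 28 x 28 input, step by step:\n",
"\n",
"1. `layer1` ($k=5$, $p=2$, $s=1$): $(28 + 4 - 5) + 1 = 28$\n",
"2. `pool1` ($k=2$, $p=0$, $s=2$): $\\lfloor (28 - 2)/2 \\rfloor + 1 = 14$\n",
"3. `layer2` ($k=3$, $p=0$, $s=1$): $(14 - 3) + 1 = 12$\n",
"4. `layer3` ($k=3$, $p=0$, $s=1$): $(12 - 3) + 1 = 10$\n",
"5. `pool2` ($k=2$, $p=0$, $s=2$): $\\lfloor (10 - 2)/2 \\rfloor + 1 = 5$\n",
"\n",
"With 64 channels after `layer3`, flattening gives $64 \\times 5 \\times 5 = 1600$ features, which is the `5*5*64` passed to `NN.Linear`."
]
},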
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.9031, 0.1854, -1.2564, 0.0946, -0.9428, 0.9311, -0.4686, -0.5068,\n",
" -0.3318, -0.6995]], grad_fn=<AddmmBackward>)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cnn = CNN()\n",
"# Sanity check: pass a random input through the network; if no error is raised, the shapes line up\n",
"cnn(torch.rand(1,1,28,28))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CNN(\n",
" (layer1): Sequential(\n",
" (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
" (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" )\n",
" (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (layer2): Sequential(\n",
" (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" )\n",
" (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (fc): Linear(in_features=1600, out_features=10, bias=True)\n",
")\n"
]
}
],
"source": [
"# Print the network as a final check\n",
"print(cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the model is defined, we have to decide where it will be computed, on the CPU or on the GPU, so we need one more parameter."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"DEVICE=torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
"    DEVICE=torch.device(\"cuda\")\n",
"print(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Move the network to the selected device first\n",
"cnn=cnn.to(DEVICE)"
]
},
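{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loss function\n",
"Multi-class classification uses softmax to turn the result of the network's forward pass into a probability distribution, so we use the cross-entropy loss.\n",
"In PyTorch, `NN.CrossEntropyLoss` combines `nn.LogSoftmax()` and `nn.NLLLoss()` (see [CrossEntropyLoss](https://pytorch.org/docs/stable/nn.html#crossentropyloss)). We could also write the two steps separately; here we do it in one step for convenience.\n"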
]
},
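{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the relationship between the two steps concrete, here is a sketch of what the combined loss computes for a single sample. The symbols are assumptions introduced for illustration: $z$ is the vector of 10 raw network outputs (logits) and $y$ is the index of the true class.\n",
"\n",
"$$\\text{LogSoftmax}(z)_j = z_j - \\log \\sum_{k=0}^{9} e^{z_k}$$\n",
"\n",
"$$\\ell(z, y) = -\\text{LogSoftmax}(z)_y = -z_y + \\log \\sum_{k=0}^{9} e^{z_k}$$\n",
"\n",
"With the default `reduction='mean'`, the loss of a batch is the average of $\\ell$ over its samples. This is also why the network ends with a plain linear layer and has no softmax of its own: the loss applies the log-softmax internally."
]
},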
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Move the loss module to the device as well (this only matters if it holds tensors such as class weights)\n",
"criterion = NN.CrossEntropyLoss().to(DEVICE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimizer\n",
"The Adam optimizer: simple, brute-force, and above all it lets us be lazy."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Another hyperparameter: the learning rate\n",
"LEARNING_RATE=0.01"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# The optimizer does not need to be moved to the GPU\n",
"optimizer = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Another hyperparameter: the number of training epochs\n",
"TOTAL_EPOCHS=50"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch : 1/50, Iter : 100/234, Loss: 0.4569\n",
"Epoch : 1/50, Iter : 200/234, Loss: 0.3623\n",
"Epoch : 2/50, Iter : 100/234, Loss: 0.2648\n",
"Epoch : 2/50, Iter : 200/234, Loss: 0.3044\n",
"Epoch : 3/50, Iter : 100/234, Loss: 0.2107\n",
"Epoch : 3/50, Iter : 200/234, Loss: 0.3022\n",
"Epoch : 4/50, Iter : 100/234, Loss: 0.2583\n",
"Epoch : 4/50, Iter : 200/234, Loss: 0.2837\n",
"Epoch : 5/50, Iter : 100/234, Loss: 0.2377\n",
"Epoch : 5/50, Iter : 200/234, Loss: 0.2422\n",
"Epoch : 6/50, Iter : 100/234, Loss: 0.1537\n",
"Epoch : 6/50, Iter : 200/234, Loss: 0.2270\n",
"Epoch : 7/50, Iter : 100/234, Loss: 0.1485\n",
"Epoch : 7/50, Iter : 200/234, Loss: 0.1740\n",
"Epoch : 8/50, Iter : 100/234, Loss: 0.3264\n",
"Epoch : 8/50, Iter : 200/234, Loss: 0.2096\n",
"Epoch : 9/50, Iter : 100/234, Loss: 0.1844\n",
"Epoch : 9/50, Iter : 200/234, Loss: 0.1927\n",
"Epoch : 10/50, Iter : 100/234, Loss: 0.1343\n",
"Epoch : 10/50, Iter : 200/234, Loss: 0.2225\n",
"Epoch : 11/50, Iter : 100/234, Loss: 0.1251\n",
"Epoch : 11/50, Iter : 200/234, Loss: 0.1789\n",
"Epoch : 12/50, Iter : 100/234, Loss: 0.1439\n",
"Epoch : 12/50, Iter : 200/234, Loss: 0.1290\n",
"Epoch : 13/50, Iter : 100/234, Loss: 0.2017\n",
"Epoch : 13/50, Iter : 200/234, Loss: 0.1130\n",
"Epoch : 14/50, Iter : 100/234, Loss: 0.0992\n",
"Epoch : 14/50, Iter : 200/234, Loss: 0.1736\n",
"Epoch : 15/50, Iter : 100/234, Loss: 0.0920\n",
"Epoch : 15/50, Iter : 200/234, Loss: 0.1557\n",
"Epoch : 16/50, Iter : 100/234, Loss: 0.0914\n",
"Epoch : 16/50, Iter : 200/234, Loss: 0.1508\n",
"Epoch : 17/50, Iter : 100/234, Loss: 0.1273\n",
"Epoch : 17/50, Iter : 200/234, Loss: 0.1982\n",
"Epoch : 18/50, Iter : 100/234, Loss: 0.1752\n",
"Epoch : 18/50, Iter : 200/234, Loss: 0.1517\n",
"Epoch : 19/50, Iter : 100/234, Loss: 0.0586\n",
"Epoch : 19/50, Iter : 200/234, Loss: 0.0984\n",
"Epoch : 20/50, Iter : 100/234, Loss: 0.1409\n",
"Epoch : 20/50, Iter : 200/234, Loss: 0.1286\n",
"Epoch : 21/50, Iter : 100/234, Loss: 0.0900\n",
"Epoch : 21/50, Iter : 200/234, Loss: 0.1168\n",
"Epoch : 22/50, Iter : 100/234, Loss: 0.0755\n",
"Epoch : 22/50, Iter : 200/234, Loss: 0.1217\n",
"Epoch : 23/50, Iter : 100/234, Loss: 0.0703\n",
"Epoch : 23/50, Iter : 200/234, Loss: 0.1383\n",
"Epoch : 24/50, Iter : 100/234, Loss: 0.0916\n",
"Epoch : 24/50, Iter : 200/234, Loss: 0.0685\n",
"Epoch : 25/50, Iter : 100/234, Loss: 0.0947\n",
"Epoch : 25/50, Iter : 200/234, Loss: 0.1244\n",
"Epoch : 26/50, Iter : 100/234, Loss: 0.0615\n",
"Epoch : 26/50, Iter : 200/234, Loss: 0.0478\n",
"Epoch : 27/50, Iter : 100/234, Loss: 0.0280\n",
"Epoch : 27/50, Iter : 200/234, Loss: 0.0459\n",
"Epoch : 28/50, Iter : 100/234, Loss: 0.0213\n",
"Epoch : 28/50, Iter : 200/234, Loss: 0.0764\n",
"Epoch : 29/50, Iter : 100/234, Loss: 0.0391\n",
"Epoch : 29/50, Iter : 200/234, Loss: 0.0899\n",
"Epoch : 30/50, Iter : 100/234, Loss: 0.0541\n",
"Epoch : 30/50, Iter : 200/234, Loss: 0.0750\n",
"Epoch : 31/50, Iter : 100/234, Loss: 0.0605\n",
"Epoch : 31/50, Iter : 200/234, Loss: 0.0766\n",
"Epoch : 32/50, Iter : 100/234, Loss: 0.1368\n",
"Epoch : 32/50, Iter : 200/234, Loss: 0.0588\n",
"Epoch : 33/50, Iter : 100/234, Loss: 0.0253\n",
"Epoch : 33/50, Iter : 200/234, Loss: 0.0705\n",
"Epoch : 34/50, Iter : 100/234, Loss: 0.0248\n",
"Epoch : 34/50, Iter : 200/234, Loss: 0.0751\n",
"Epoch : 35/50, Iter : 100/234, Loss: 0.0449\n",
"Epoch : 35/50, Iter : 200/234, Loss: 0.1006\n",
"Epoch : 36/50, Iter : 100/234, Loss: 0.0281\n",
"Epoch : 36/50, Iter : 200/234, Loss: 0.0418\n",
"Epoch : 37/50, Iter : 100/234, Loss: 0.0547\n",
"Epoch : 37/50, Iter : 200/234, Loss: 0.1003\n",
"Epoch : 38/50, Iter : 100/234, Loss: 0.0694\n",
"Epoch : 38/50, Iter : 200/234, Loss: 0.0340\n",
"Epoch : 39/50, Iter : 100/234, Loss: 0.0620\n",
"Epoch : 39/50, Iter : 200/234, Loss: 0.1004\n",
"Epoch : 40/50, Iter : 100/234, Loss: 0.0588\n",
"Epoch : 40/50, Iter : 200/234, Loss: 0.0309\n",
"Epoch : 41/50, Iter : 100/234, Loss: 0.0387\n",
"Epoch : 41/50, Iter : 200/234, Loss: 0.0136\n",
"Epoch : 42/50, Iter : 100/234, Loss: 0.0149\n",
"Epoch : 42/50, Iter : 200/234, Loss: 0.0448\n",
"Epoch : 43/50, Iter : 100/234, Loss: 0.0076\n",
"Epoch : 43/50, Iter : 200/234, Loss: 0.0593\n",
"Epoch : 44/50, Iter : 100/234, Loss: 0.0267\n",
"Epoch : 44/50, Iter : 200/234, Loss: 0.0308\n",
"Epoch : 45/50, Iter : 100/234, Loss: 0.0150\n",
"Epoch : 45/50, Iter : 200/234, Loss: 0.0764\n",
"Epoch : 46/50, Iter : 100/234, Loss: 0.0221\n",
"Epoch : 46/50, Iter : 200/234, Loss: 0.0325\n",
"Epoch : 47/50, Iter : 100/234, Loss: 0.0190\n",
"Epoch : 47/50, Iter : 200/234, Loss: 0.0359\n",
"Epoch : 48/50, Iter : 100/234, Loss: 0.0256\n",
"Epoch : 48/50, Iter : 200/234, Loss: 0.0374\n",
"Epoch : 49/50, Iter : 100/234, Loss: 0.0198\n",
"Epoch : 49/50, Iter : 200/234, Loss: 0.0300\n",
"Epoch : 50/50, Iter : 100/234, Loss: 0.0465\n",
"Epoch : 50/50, Iter : 200/234, Loss: 0.0558\n",
"Wall time: 7min 18s\n"
]
}
],
"source": [
"%%time\n",
"# Record the loss of every iteration\n",
"losses = []\n",
"for epoch in range(TOTAL_EPOCHS):\n",
"    for i, (images, labels) in enumerate(train_loader):\n",
"        images = images.float().to(DEVICE)\n",
"        labels = labels.to(DEVICE)\n",
"        # Reset the gradients\n",
"        optimizer.zero_grad()\n",
"        outputs = cnn(images)\n",
"        # Compute the loss\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        losses.append(loss.item())\n",
"        if (i+1) % 100 == 0:\n",
"            print ('Epoch : %d/%d, Iter : %d/%d, Loss: %.4f'%(epoch+1, TOTAL_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, loss.item()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## After training\n",
"### Visualizing the loss"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.xkcd()\n",
"# One loss value was recorded per iteration, so the x axis counts iterations\n",
"plt.xlabel('Iteration #')\n",
"plt.ylabel('Loss')\n",
"plt.plot(losses)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Saving the model"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"torch.save(cnn.state_dict(), \"fm-cnn3.pth\")\n",
"# To load the saved weights later, use:\n",
"#cnn.load_state_dict(torch.load(\"fm-cnn3.pth\"))"
]
},
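{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model evaluation\n",
"\n",
"Model evaluation means evaluating the model on the test set. It would normally be done during training; to keep the explanation simple, here we evaluate once training has finished."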
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 90.0000 %\n"
]
}
],
"source": [
"cnn.eval()\n",
"correct = 0\n",
"total = 0\n",
"# Gradients are not needed for evaluation\n",
"with torch.no_grad():\n",
"    for images, labels in test_loader:\n",
"        images = images.float().to(DEVICE)\n",
"        outputs = cnn(images).cpu()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total += labels.size(0)\n",
"        # .item() yields a Python int, so the percentage below is not truncated by integer division\n",
"        correct += (predicted == labels).sum().item()\n",
"print('Accuracy: %.4f %%' % (100 * correct / total))"
]
},
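{
"cell_type": "markdown",
"metadata": {},
"source": [
"The evaluation steps are:\n",
"1. Switch the network to eval mode.\n",
"2. Feed the images into the network to get the outputs.\n",
"3. Take the index of the largest of the 10 output scores as the predicted label.\n",
"4. Count the correct predictions."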
]
},
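{
"cell_type": "markdown",
"metadata": {},
"source": [
"Steps 3 and 4 can be written compactly. This is a sketch with notation assumed here for illustration: $z_i$ is the output vector for test sample $i$, $y_i$ is its true label, and $N$ is the number of test samples.\n",
"\n",
"$$\\hat{y}_i = \\arg\\max_{j} z_{i,j}, \\qquad \\text{accuracy} = \\frac{1}{N} \\sum_{i=1}^{N} \\mathbf{1}[\\hat{y}_i = y_i]$$\n",
"\n",
"In the code, the call to `torch.max` along dimension 1 returns both the maximum values and their indices; the indices are the $\\hat{y}_i$, and `total` plays the role of $N$."
]
},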
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Further optimization"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch : 1/20, Iter : 100/234, Loss: 0.0096\n",
"Epoch : 1/20, Iter : 200/234, Loss: 0.0124\n",
"Epoch : 2/20, Iter : 100/234, Loss: 0.0031\n",
"Epoch : 2/20, Iter : 200/234, Loss: 0.0020\n",
"Epoch : 3/20, Iter : 100/234, Loss: 0.0013\n",
"Epoch : 3/20, Iter : 200/234, Loss: 0.0041\n",
"Epoch : 4/20, Iter : 100/234, Loss: 0.0016\n",
"Epoch : 4/20, Iter : 200/234, Loss: 0.0023\n",
"Epoch : 5/20, Iter : 100/234, Loss: 0.0010\n",
"Epoch : 5/20, Iter : 200/234, Loss: 0.0008\n",
"Epoch : 6/20, Iter : 100/234, Loss: 0.0017\n",
"Epoch : 6/20, Iter : 200/234, Loss: 0.0010\n",
"Epoch : 7/20, Iter : 100/234, Loss: 0.0009\n",
"Epoch : 7/20, Iter : 200/234, Loss: 0.0009\n",
"Epoch : 8/20, Iter : 100/234, Loss: 0.0005\n",
"Epoch : 8/20, Iter : 200/234, Loss: 0.0008\n",
"Epoch : 9/20, Iter : 100/234, Loss: 0.0005\n",
"Epoch : 9/20, Iter : 200/234, Loss: 0.0006\n",
"Epoch : 10/20, Iter : 100/234, Loss: 0.0016\n",
"Epoch : 10/20, Iter : 200/234, Loss: 0.0011\n",
"Epoch : 11/20, Iter : 100/234, Loss: 0.0003\n",
"Epoch : 11/20, Iter : 200/234, Loss: 0.0009\n",
"Epoch : 12/20, Iter : 100/234, Loss: 0.0010\n",
"Epoch : 12/20, Iter : 200/234, Loss: 0.0002\n",
"Epoch : 13/20, Iter : 100/234, Loss: 0.0004\n",
"Epoch : 13/20, Iter : 200/234, Loss: 0.0005\n",
"Epoch : 14/20, Iter : 100/234, Loss: 0.0003\n",
"Epoch : 14/20, Iter : 200/234, Loss: 0.0004\n",
"Epoch : 15/20, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 15/20, Iter : 200/234, Loss: 0.0005\n",
"Epoch : 16/20, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 16/20, Iter : 200/234, Loss: 0.0007\n",
"Epoch : 17/20, Iter : 100/234, Loss: 0.0003\n",
"Epoch : 17/20, Iter : 200/234, Loss: 0.0002\n",
"Epoch : 18/20, Iter : 100/234, Loss: 0.0004\n",
"Epoch : 18/20, Iter : 200/234, Loss: 0.0001\n",
"Epoch : 19/20, Iter : 100/234, Loss: 0.0003\n",
"Epoch : 19/20, Iter : 200/234, Loss: 0.0005\n",
"Epoch : 20/20, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 20/20, Iter : 200/234, Loss: 0.0002\n",
"Wall time: 2min 21s\n"
]
}
],
"source": [
"%%time\n",
"# Lower the learning rate and the number of epochs, then continue training\n",
"cnn.train()\n",
"LEARNING_RATE = LEARNING_RATE / 10\n",
"TOTAL_EPOCHS = 20\n",
"optimizer = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)\n",
"losses = []\n",
"for epoch in range(TOTAL_EPOCHS):\n",
"    for i, (images, labels) in enumerate(train_loader):\n",
"        images = images.float().to(DEVICE)\n",
"        labels = labels.to(DEVICE)\n",
"        # Reset the gradients\n",
"        optimizer.zero_grad()\n",
"        outputs = cnn(images)\n",
"        # Compute the loss; only the scalar value is copied to the CPU, via .item()\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        losses.append(loss.item())\n",
"        if (i+1) % 100 == 0:\n",
"            print ('Epoch : %d/%d, Iter : %d/%d, Loss: %.4f'%(epoch+1, TOTAL_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, loss.item()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize the loss"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.xkcd()\n",
"# One loss value was recorded per iteration, so the x axis counts iterations\n",
"plt.xlabel('Iteration #')\n",
"plt.ylabel('Loss')\n",
"plt.plot(losses)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating again"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 91.0000 %\n"
]
}
],
"source": [
"cnn.eval()\n",
"correct = 0\n",
"total = 0\n",
"# Gradients are not needed for evaluation\n",
"with torch.no_grad():\n",
"    for images, labels in test_loader:\n",
"        images = images.float().to(DEVICE)\n",
"        outputs = cnn(images).cpu()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total += labels.size(0)\n",
"        # .item() yields a Python int, so the percentage below is not truncated by integer division\n",
"        correct += (predicted == labels).sum().item()\n",
"print('Accuracy: %.4f %%' % (100 * correct / total))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch : 1/10, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 1/10, Iter : 200/234, Loss: 0.0001\n",
"Epoch : 2/10, Iter : 100/234, Loss: 0.0001\n",
"Epoch : 2/10, Iter : 200/234, Loss: 0.0005\n",
"Epoch : 3/10, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 3/10, Iter : 200/234, Loss: 0.0001\n",
"Epoch : 4/10, Iter : 100/234, Loss: 0.0003\n",
"Epoch : 4/10, Iter : 200/234, Loss: 0.0001\n",
"Epoch : 5/10, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 5/10, Iter : 200/234, Loss: 0.0003\n",
"Epoch : 6/10, Iter : 100/234, Loss: 0.0002\n",
"Epoch : 6/10, Iter : 200/234, Loss: 0.0002\n",
"Epoch : 7/10, Iter : 100/234, Loss: 0.0001\n",
"Epoch : 7/10, Iter : 200/234, Loss: 0.0002\n",
"Epoch : 8/10, Iter : 100/234, Loss: 0.0008\n",
"Epoch : 8/10, Iter : 200/234, Loss: 0.0008\n",
"Epoch : 9/10, Iter : 100/234, Loss: 0.0005\n",
"Epoch : 9/10, Iter : 200/234, Loss: 0.0002\n",
"Epoch : 10/10, Iter : 100/234, Loss: 0.0006\n",
"Epoch : 10/10, Iter : 200/234, Loss: 0.0002\n",
"Wall time: 1min 9s\n"
]
}
],
"source": [
"%%time\n",
"# Keep the learning rate at 0.001 and train for 10 more epochs\n",
"cnn.train()\n",
"TOTAL_EPOCHS = 10\n",
"optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)\n",
"losses = []\n",
"for epoch in range(TOTAL_EPOCHS):\n",
"    for i, (images, labels) in enumerate(train_loader):\n",
"        images = images.float().to(DEVICE)\n",
"        labels = labels.to(DEVICE)\n",
"        # Reset the gradients\n",
"        optimizer.zero_grad()\n",
"        outputs = cnn(images)\n",
"        # Compute the loss\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        losses.append(loss.item())\n",
"        if (i+1) % 100 == 0:\n",
"            print ('Epoch : %d/%d, Iter : %d/%d, Loss: %.4f'%(epoch+1, TOTAL_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, loss.item()))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.xkcd()\n",
"# One loss value was recorded per iteration, so the x axis counts iterations\n",
"plt.xlabel('Iteration #')\n",
"plt.ylabel('Loss')\n",
"plt.plot(losses)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 91.0000 %\n"
]
}
],
"source": [
"cnn.eval()\n",
"correct = 0\n",
"total = 0\n",
"# Gradients are not needed for evaluation\n",
"with torch.no_grad():\n",
"    for images, labels in test_loader:\n",
"        images = images.float().to(DEVICE)\n",
"        outputs = cnn(images).cpu()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total += labels.size(0)\n",
"        # .item() yields a Python int, so the percentage below is not truncated by integer division\n",
"        correct += (predicted == labels).sum().item()\n",
"print('Accuracy: %.4f %%' % (100 * correct / total))"
]
},
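{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss went down but the accuracy did not improve, which suggests we are close to the limit of this model; to improve further we would have to change the model itself. Another sign that a model has reached its limit is the loss curve: in the last training run the loss shows no clear downward trend and only oscillates, which indicates there is little room left for optimization.\n",
"\n",
"These simple steps also show how brute-force Adam is: just changing the learning rate was enough to get an improvement. As a rough rule of thumb, Adam is typically used with a warm-up at 0.1, then most of the training at 0.01, and a final phase at 0.001; smaller learning rates are usually not needed.\n",
"\n",
"## Summary\n",
"To finish, let us recap the hyperparameters:\n",
"\n",
"`BATCH_SIZE`: the batch size, i.e. how many samples are processed together in each training step. It is set when the DataLoader is created and has to be chosen to fit the model and the available GPU memory. If you hit an out-of-memory (OOM) error, reduce it first. It is usually set to a multiple of 2, commonly a power of 2.\n",
"\n",
"`DEVICE`: the device used for computation, mainly a choice between CPU and GPU.\n",
"\n",
"`LEARNING_RATE`: the learning rate, used by the optimizer when it updates the weights after backpropagation.\n",
"\n",
"`TOTAL_EPOCHS`: the number of training epochs. In practice it is usually decided by thresholds on the loss, the accuracy, and similar metrics.\n",
"\n",
"Strictly speaking, the choice of optimizer and loss function are hyperparameters too, but we will not go into them here."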
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}