pytorch-handbook/chapter3/3.2-mnist.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.0.0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms\n",
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.2 Handwritten Digit Recognition on the MNIST Dataset\n",
"\n",
"## 3.2.1 Introducing the Dataset\n",
"MNIST contains 60,000 28x28 training samples and 10,000 test samples. Nearly every tutorial \"takes a crack\" at it, and it has become something of a \"canonical\" benchmark; you could say it is the Hello World of computer vision. So we will also use MNIST for our hands-on example.\n",
"\n",
"We mentioned LeNet-5 earlier when introducing convolutional neural networks. LeNet-5 was so influential because, with the hardware of its day, it raised recognition accuracy on MNIST to 99%. Here we will build a convolutional neural network from scratch ourselves and likewise reach 99% accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.2.2 Handwritten Digit Recognition\n",
"First, we define some hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 512 # needs roughly 2 GB of GPU memory\n",
"EPOCHS = 20 # total number of training epochs\n",
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # let torch decide whether to use the GPU; a GPU environment is recommended because it is much faster"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because PyTorch ships with the MNIST dataset, we can use it here directly.\n",
"The first run will create a data folder and take some time to download; if the data has been downloaded before, it will not be downloaded again.\n",
"\n",
"Since the official package already implements the Dataset, we can read the data directly with a DataLoader. (The values 0.1307 and 0.3081 passed to Normalize are the mean and standard deviation of the MNIST training pixels.)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Processing...\n",
"Done!\n"
]
}
],
"source": [
"train_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST('data', train=True, download=True, \n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=BATCH_SIZE, shuffle=True)"
]
},
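{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an optional addition, assuming the train_loader defined above), we can fetch one batch and confirm its shape:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: inspect one batch from train_loader\n",
"images, labels = next(iter(train_loader))\n",
"print(images.shape) # expected: torch.Size([512, 1, 28, 28])\n",
"print(labels.shape) # expected: torch.Size([512])"
]
},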
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The test set:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"test_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST('data', train=False, transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=BATCH_SIZE, shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we define a network. It contains two convolutional layers, conv1 and conv2, followed by two linear layers as the output head. The final output has 10 dimensions, which we treat as the labels 0-9 to determine which digit has been recognized.\n",
"\n",
"We recommend annotating the input and output dimensions of every layer as comments; this makes the code much easier for later readers to follow."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class ConvNet(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # batch*1*28*28: each step feeds in batch samples; 1 input channel (grayscale images) at 28x28 resolution\n",
"        # For the Conv2d layers below: first argument = input channels, second = output channels, third = kernel size\n",
"        self.conv1 = nn.Conv2d(1, 10, 5) # 1 input channel, 10 output channels, kernel size 5\n",
"        self.conv2 = nn.Conv2d(10, 20, 3) # 10 input channels, 20 output channels, kernel size 3\n",
"        # For the Linear (fully connected) layers below: first argument = input features, second = output features\n",
"        self.fc1 = nn.Linear(20*10*10, 500) # 2000 input features (20*10*10), 500 output features\n",
"        self.fc2 = nn.Linear(500, 10) # 500 input features, 10 output features, i.e. 10 classes\n",
"    def forward(self, x):\n",
"        in_size = x.size(0) # here in_size = 512, the value of BATCH_SIZE; the input x can be viewed as a 512*1*28*28 tensor\n",
"        out = self.conv1(x) # batch*1*28*28 -> batch*10*24*24 (a 5x5 convolution over a 28x28 image yields 24x24)\n",
"        out = F.relu(out) # batch*10*24*24 (the ReLU activation does not change the shape)\n",
"        out = F.max_pool2d(out, 2, 2) # batch*10*24*24 -> batch*10*12*12 (2x2 max pooling halves each spatial dimension)\n",
"        out = self.conv2(out) # batch*10*12*12 -> batch*20*10*10 (another convolution, kernel size 3)\n",
"        out = F.relu(out) # batch*20*10*10\n",
"        out = out.view(in_size, -1) # batch*20*10*10 -> batch*2000 (-1 in the second dimension means it is inferred automatically; here it is 20*10*10)\n",
"        out = self.fc1(out) # batch*2000 -> batch*500\n",
"        out = F.relu(out) # batch*500\n",
"        out = self.fc2(out) # batch*500 -> batch*10\n",
"        out = F.log_softmax(out, dim=1) # compute log(softmax(x)); this pairs with the nll_loss used later\n",
"        return out"
]
},
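{
"cell_type": "markdown",
"metadata": {},
"source": [
"To double-check the dimension annotations above, here is a small shape test (an optional sketch, not part of the original notebook): push a dummy batch through an untrained ConvNet and confirm the output is batch*10."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional shape check: run a dummy batch through an untrained ConvNet on the CPU\n",
"dummy = torch.randn(4, 1, 28, 28) # a fake batch of 4 grayscale 28x28 images\n",
"print(ConvNet()(dummy).shape) # expected: torch.Size([4, 10])"
]
},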
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We instantiate the network, and after instantiation use the .to method to move it onto the GPU.\n",
"\n",
"For the optimizer, we go straight for the simple, brute-force choice of Adam."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = ConvNet().to(DEVICE)\n",
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we define the training function, wrapping all of the training operations inside it."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def train(model, device, train_loader, optimizer, epoch):\n",
"    model.train()\n",
"    for batch_idx, (data, target) in enumerate(train_loader):\n",
"        data, target = data.to(device), target.to(device)\n",
"        optimizer.zero_grad()\n",
"        output = model(data)\n",
"        loss = F.nll_loss(output, target)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        if (batch_idx + 1) % 30 == 0:\n",
"            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
"                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
"                100. * batch_idx / len(train_loader), loss.item()))"
]
},
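{
"cell_type": "markdown",
"metadata": {},
"source": [
"A side note on the loss (an optional addition): F.nll_loss expects log-probabilities, which is exactly why the network ends with log_softmax; the combination is equivalent to cross-entropy, as this small check shows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check: log_softmax followed by nll_loss equals cross_entropy\n",
"logits = torch.randn(3, 10) # fake scores for 3 samples over 10 classes\n",
"targets = torch.tensor([1, 4, 9]) # fake class labels\n",
"a = F.nll_loss(F.log_softmax(logits, dim=1), targets)\n",
"b = F.cross_entropy(logits, targets)\n",
"print(torch.allclose(a, b)) # True"
]
},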
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The test operations are likewise wrapped into a single function."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def test(model, device, test_loader):\n",
"    model.eval()\n",
"    test_loss = 0\n",
"    correct = 0\n",
"    with torch.no_grad():\n",
"        for data, target in test_loader:\n",
"            data, target = data.to(device), target.to(device)\n",
"            output = model(data)\n",
"            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up the loss over the whole batch\n",
"            pred = output.max(1, keepdim=True)[1] # find the index with the highest log-probability\n",
"            correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
"    test_loss /= len(test_loader.dataset)\n",
"    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
"        test_loss, correct, len(test_loader.dataset),\n",
"        100. * correct / len(test_loader.dataset)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we start training. Here the benefit of the encapsulation shows itself: a couple of lines are all it takes."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 [14848/60000 (25%)]\tLoss: 0.272529\n",
"Train Epoch: 1 [30208/60000 (50%)]\tLoss: 0.235455\n",
"Train Epoch: 1 [45568/60000 (75%)]\tLoss: 0.101858\n",
"\n",
"Test set: Average loss: 0.1018, Accuracy: 9695/10000 (97%)\n",
"\n",
"Train Epoch: 2 [14848/60000 (25%)]\tLoss: 0.057989\n",
"Train Epoch: 2 [30208/60000 (50%)]\tLoss: 0.083935\n",
"Train Epoch: 2 [45568/60000 (75%)]\tLoss: 0.051921\n",
"\n",
"Test set: Average loss: 0.0523, Accuracy: 9825/10000 (98%)\n",
"\n",
"Train Epoch: 3 [14848/60000 (25%)]\tLoss: 0.045383\n",
"Train Epoch: 3 [30208/60000 (50%)]\tLoss: 0.049402\n",
"Train Epoch: 3 [45568/60000 (75%)]\tLoss: 0.061366\n",
"\n",
"Test set: Average loss: 0.0408, Accuracy: 9866/10000 (99%)\n",
"\n",
"Train Epoch: 4 [14848/60000 (25%)]\tLoss: 0.035253\n",
"Train Epoch: 4 [30208/60000 (50%)]\tLoss: 0.038444\n",
"Train Epoch: 4 [45568/60000 (75%)]\tLoss: 0.036877\n",
"\n",
"Test set: Average loss: 0.0433, Accuracy: 9859/10000 (99%)\n",
"\n",
"Train Epoch: 5 [14848/60000 (25%)]\tLoss: 0.038996\n",
"Train Epoch: 5 [30208/60000 (50%)]\tLoss: 0.020670\n",
"Train Epoch: 5 [45568/60000 (75%)]\tLoss: 0.034658\n",
"\n",
"Test set: Average loss: 0.0339, Accuracy: 9885/10000 (99%)\n",
"\n",
"Train Epoch: 6 [14848/60000 (25%)]\tLoss: 0.067320\n",
"Train Epoch: 6 [30208/60000 (50%)]\tLoss: 0.016328\n",
"Train Epoch: 6 [45568/60000 (75%)]\tLoss: 0.017037\n",
"\n",
"Test set: Average loss: 0.0348, Accuracy: 9881/10000 (99%)\n",
"\n",
"Train Epoch: 7 [14848/60000 (25%)]\tLoss: 0.022150\n",
"Train Epoch: 7 [30208/60000 (50%)]\tLoss: 0.009608\n",
"Train Epoch: 7 [45568/60000 (75%)]\tLoss: 0.012742\n",
"\n",
"Test set: Average loss: 0.0346, Accuracy: 9895/10000 (99%)\n",
"\n",
"Train Epoch: 8 [14848/60000 (25%)]\tLoss: 0.010173\n",
"Train Epoch: 8 [30208/60000 (50%)]\tLoss: 0.019482\n",
"Train Epoch: 8 [45568/60000 (75%)]\tLoss: 0.012159\n",
"\n",
"Test set: Average loss: 0.0323, Accuracy: 9886/10000 (99%)\n",
"\n",
"Train Epoch: 9 [14848/60000 (25%)]\tLoss: 0.007792\n",
"Train Epoch: 9 [30208/60000 (50%)]\tLoss: 0.006970\n",
"Train Epoch: 9 [45568/60000 (75%)]\tLoss: 0.004989\n",
"\n",
"Test set: Average loss: 0.0294, Accuracy: 9909/10000 (99%)\n",
"\n",
"Train Epoch: 10 [14848/60000 (25%)]\tLoss: 0.003764\n",
"Train Epoch: 10 [30208/60000 (50%)]\tLoss: 0.005944\n",
"Train Epoch: 10 [45568/60000 (75%)]\tLoss: 0.001866\n",
"\n",
"Test set: Average loss: 0.0361, Accuracy: 9902/10000 (99%)\n",
"\n",
"Train Epoch: 11 [14848/60000 (25%)]\tLoss: 0.002737\n",
"Train Epoch: 11 [30208/60000 (50%)]\tLoss: 0.014134\n",
"Train Epoch: 11 [45568/60000 (75%)]\tLoss: 0.001365\n",
"\n",
"Test set: Average loss: 0.0309, Accuracy: 9905/10000 (99%)\n",
"\n",
"Train Epoch: 12 [14848/60000 (25%)]\tLoss: 0.003344\n",
"Train Epoch: 12 [30208/60000 (50%)]\tLoss: 0.003090\n",
"Train Epoch: 12 [45568/60000 (75%)]\tLoss: 0.004847\n",
"\n",
"Test set: Average loss: 0.0318, Accuracy: 9902/10000 (99%)\n",
"\n",
"Train Epoch: 13 [14848/60000 (25%)]\tLoss: 0.001278\n",
"Train Epoch: 13 [30208/60000 (50%)]\tLoss: 0.003016\n",
"Train Epoch: 13 [45568/60000 (75%)]\tLoss: 0.001328\n",
"\n",
"Test set: Average loss: 0.0358, Accuracy: 9906/10000 (99%)\n",
"\n",
"Train Epoch: 14 [14848/60000 (25%)]\tLoss: 0.002219\n",
"Train Epoch: 14 [30208/60000 (50%)]\tLoss: 0.003487\n",
"Train Epoch: 14 [45568/60000 (75%)]\tLoss: 0.014429\n",
"\n",
"Test set: Average loss: 0.0376, Accuracy: 9896/10000 (99%)\n",
"\n",
"Train Epoch: 15 [14848/60000 (25%)]\tLoss: 0.003042\n",
"Train Epoch: 15 [30208/60000 (50%)]\tLoss: 0.002974\n",
"Train Epoch: 15 [45568/60000 (75%)]\tLoss: 0.000871\n",
"\n",
"Test set: Average loss: 0.0346, Accuracy: 9909/10000 (99%)\n",
"\n",
"Train Epoch: 16 [14848/60000 (25%)]\tLoss: 0.000618\n",
"Train Epoch: 16 [30208/60000 (50%)]\tLoss: 0.003164\n",
"Train Epoch: 16 [45568/60000 (75%)]\tLoss: 0.007245\n",
"\n",
"Test set: Average loss: 0.0357, Accuracy: 9905/10000 (99%)\n",
"\n",
"Train Epoch: 17 [14848/60000 (25%)]\tLoss: 0.001874\n",
"Train Epoch: 17 [30208/60000 (50%)]\tLoss: 0.013951\n",
"Train Epoch: 17 [45568/60000 (75%)]\tLoss: 0.000729\n",
"\n",
"Test set: Average loss: 0.0322, Accuracy: 9922/10000 (99%)\n",
"\n",
"Train Epoch: 18 [14848/60000 (25%)]\tLoss: 0.002581\n",
"Train Epoch: 18 [30208/60000 (50%)]\tLoss: 0.001396\n",
"Train Epoch: 18 [45568/60000 (75%)]\tLoss: 0.015521\n",
"\n",
"Test set: Average loss: 0.0389, Accuracy: 9914/10000 (99%)\n",
"\n",
"Train Epoch: 19 [14848/60000 (25%)]\tLoss: 0.000283\n",
"Train Epoch: 19 [30208/60000 (50%)]\tLoss: 0.001385\n",
"Train Epoch: 19 [45568/60000 (75%)]\tLoss: 0.011184\n",
"\n",
"Test set: Average loss: 0.0383, Accuracy: 9901/10000 (99%)\n",
"\n",
"Train Epoch: 20 [14848/60000 (25%)]\tLoss: 0.000472\n",
"Train Epoch: 20 [30208/60000 (50%)]\tLoss: 0.003306\n",
"Train Epoch: 20 [45568/60000 (75%)]\tLoss: 0.018017\n",
"\n",
"Test set: Average loss: 0.0393, Accuracy: 9899/10000 (99%)\n",
"\n"
]
}
],
"source": [
"for epoch in range(1, EPOCHS + 1):\n",
" train(model, DEVICE, train_loader, optimizer, epoch)\n",
" test(model, DEVICE, test_loader)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the results: 99% accuracy. No problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your model cannot even handle MNIST, then it has no value at all.\n",
"\n",
"Even if your model does handle MNIST, it may still have no value at all.\n",
"\n",
"MNIST is a very simple dataset. Because of its limitations it is useful mainly for research, and the value it brings to real applications is very limited. But through this example we can fully understand the workflow of a real project:\n",
"\n",
"we find a dataset, preprocess the data, define our model, tune the hyperparameters, train and test, and then adjust the hyperparameters or the model based on the training results.\n",
"\n",
"Moreover, through this hands-on exercise we now have a good template, and future projects can follow it as an example."
]
},
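{
"cell_type": "markdown",
"metadata": {},
"source": [
"One step this template does not cover is persisting the trained weights. Below is a minimal sketch using torch.save; the filename mnist_convnet.pt is our own choice, not something from the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: save the trained weights so the template can be reused without retraining\n",
"torch.save(model.state_dict(), 'mnist_convnet.pt') # the filename is arbitrary\n",
"\n",
"# To restore later, rebuild the architecture and load the weights:\n",
"# model = ConvNet().to(DEVICE)\n",
"# model.load_state_dict(torch.load('mnist_convnet.pt'))"
]
},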
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch 1.0",
"language": "python",
"name": "pytorch1"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}