pytorch-handbook/chapter2/2.1.4-pytorch-basics-data-loader.ipynb

359 lines
9.9 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch 基础 :数据的加载和预处理\n",
"PyTorch通过torch.utils.data对一般常用的数据加载进行了封装可以很容易地实现多线程数据预读和批量加载。\n",
"并且torchvision已经预先实现了常用图像数据集包括前面使用过的CIFAR-10ImageNet、COCO、MNIST、LSUN等数据集可通过torchvision.datasets方便的调用"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.0.1.post2'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 首先要引入相关的包\n",
"import torch\n",
"#打印一下版本\n",
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset\n",
"Dataset是一个抽象类为了能够方便的读取需要将要使用的数据包装为Dataset类。\n",
"自定义的Dataset需要继承它并且实现两个成员方法\n",
"1. `__getitem__()` 该方法定义用索引(`0` 到 `len(self)`)获取一条数据或一个样本\n",
"2. `__len__()` 该方法返回数据集的总长度\n",
"\n",
"下面我们使用kaggle上的一个竞赛[bluebook for bulldozers](https://www.kaggle.com/c/bluebook-for-bulldozers/data)自定义一个数据集,为了方便介绍,我们使用里面的数据字典来做说明(因为条数少)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#引用\n",
"from torch.utils.data import Dataset\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#定义一个数据集\n",
"class BulldozerDataset(Dataset):\n",
" \"\"\" 数据集演示 \"\"\"\n",
" def __init__(self, csv_file):\n",
" \"\"\"实现初始化方法,在初始化的时候将数据读载入\"\"\"\n",
" self.df=pd.read_csv(csv_file)\n",
" def __len__(self):\n",
" '''\n",
" 返回df的长度\n",
" '''\n",
" return len(self.df)\n",
" def __getitem__(self, idx):\n",
" '''\n",
" 根据 idx 返回一行数据\n",
" '''\n",
" return self.df.iloc[idx].SalePrice"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"至此,我们的数据集已经定义完成了,我们可以实例化一个对象访问它"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ds_demo= BulldozerDataset('median_benchmark.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们可以直接使用如下命令查看数据集数据\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11573"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#实现了 __len__ 方法所以可以直接使用len获取数据总数\n",
"len(ds_demo)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"24000.0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#用索引可以直接访问对应的数据,对应 __getitem__ 方法\n",
"ds_demo[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"自定义的数据集已经创建好了,下面我们使用官方提供的数据载入器,读取数据\n",
"## Dataloader\n",
"DataLoader为我们提供了对Dataset的读取操作常用参数有batch_size(每个batch的大小)、 shuffle(是否进行shuffle操作)、 num_workers(加载数据的时候使用几个子进程)。下面做一个简单的操作"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DataLoader返回的是一个可迭代对象我们可以使用迭代器分次获取数据"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,\n",
" 24000.], dtype=torch.float64)\n"
]
}
],
"source": [
"idata=iter(dl)\n",
"print(next(idata))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"常见的用法是使用for循环对其进行遍历"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,\n",
" 24000.], dtype=torch.float64)\n"
]
}
],
"source": [
"for i, data in enumerate(dl):\n",
" print(i,data)\n",
" # 为了节约空间,这里只循环一遍\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们已经可以通过dataset定义数据集并使用Datalorder载入和遍历数据集除了这些以外PyTorch还提供能torcvision的计算机视觉扩展包里面封装了\n",
"## torchvision 包\n",
"torchvision 是PyTorch中专门用来处理图像的库PyTorch官网的安装教程中最后的pip install torchvision 就是安装这个包。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### torchvision.datasets\n",
"torchvision.datasets 可以理解为PyTorch团队自定义的dataset这些dataset帮我们提前处理好了很多的图片数据集我们拿来就可以直接使用\n",
"- MNIST\n",
"- COCO\n",
"- Captions\n",
"- Detection\n",
"- LSUN\n",
"- ImageFolder\n",
"- Imagenet-12\n",
"- CIFAR\n",
"- STL10\n",
"- SVHN\n",
"- PhotoTour\n",
"我们可以直接使用,示例如下:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import torchvision.datasets as datasets\n",
"trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录\n",
" train=True, # 表示是否加载数据库的训练集false的时候加载测试集\n",
" download=True, # 表示是否自动下载 MNIST 数据集\n",
" transform=None) # 表示是否需要对数据进行预处理none为不进行预处理\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### torchvision.models\n",
"torchvision不仅提供了常用图片数据集还提供了训练好的模型可以加载之后直接使用或者在进行迁移学习\n",
"torchvision.models模块的 子模块中包含以下模型结构。\n",
"- AlexNet\n",
"- VGG\n",
"- ResNet\n",
"- SqueezeNet\n",
"- DenseNet"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"#我们直接可以使用训练好的模型当然这个与datasets相同都是需要从服务器下载的\n",
"import torchvision.models as models\n",
"resnet18 = models.resnet18(pretrained=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### torchvision.transforms\n",
"transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from torchvision import transforms as transforms\n",
"transform = transforms.Compose([\n",
" transforms.RandomCrop(32, padding=4), #先四周填充0在把图像随机裁剪成32*32\n",
" transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转\n",
" transforms.RandomRotation((-45,45)), #随机旋转\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"肯定有人会问:(0.485, 0.456, 0.406), (0.2023, 0.1994, 0.2010) 这几个数字是什么意思?\n",
"\n",
"官方的这个帖子有详细的说明:\n",
"https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21\n",
"这些都是根据ImageNet训练的归一化参数可以直接使用我们认为这个是固定值就可以"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们已经完成了Python的基本内容的介绍下面我们要介绍神经网络的理论基础里面的公式等内容我们都使用PyTorch来实现"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}