pytorch-handbook/chapter5/5.2-Structured-Data.ipynb
2020-04-22 10:16:15 +08:00

1463 lines
47 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.4.0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"from collections import Counter\n",
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5.2 Pytorch处理结构化数据\n",
"## 简介\n",
"在介绍之前我们首先要明确下什么是结构化的数据。结构化数据可以从名称中看出是高度组织和整齐格式化的数据。它是可以放入表格和电子表格中的数据类型。对我们来说结构化数据可以理解为就是一张2维的表格例如一个csv文件就是结构化数据在英文一般被称作Tabular Data或者叫 structured data下面我们来看一下结构化数据的例子。\n",
"\n",
"一下文件来自于fastai的自带数据集\n",
"https://github.com/fastai/fastai/blob/master/examples/tabular.ipynb\n",
"fastai样例在这里\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据预处理\n",
"我们拿到的结构化数据一般都是一个csv文件或者数据库中的一张表格所以对于结构化的数据我们直接使用pasdas库处理就可以了"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['>=50k', '<50k'], dtype=object)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#读入文件\n",
"df = pd.read_csv('./data/adult.csv')\n",
"#salary是这个数据集最后要分类的结果\n",
"df['salary'].unique()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" <th>salary</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>49</td>\n",
" <td>Private</td>\n",
" <td>101320</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>12.0</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>NaN</td>\n",
" <td>Wife</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>1902</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;=50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>236746</td>\n",
" <td>Masters</td>\n",
" <td>14.0</td>\n",
" <td>Divorced</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>10520</td>\n",
" <td>0</td>\n",
" <td>45</td>\n",
" <td>United-States</td>\n",
" <td>&gt;=50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>96185</td>\n",
" <td>HS-grad</td>\n",
" <td>NaN</td>\n",
" <td>Divorced</td>\n",
" <td>NaN</td>\n",
" <td>Unmarried</td>\n",
" <td>Black</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>32</td>\n",
" <td>United-States</td>\n",
" <td>&lt;50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>38</td>\n",
" <td>Self-emp-inc</td>\n",
" <td>112847</td>\n",
" <td>Prof-school</td>\n",
" <td>15.0</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Husband</td>\n",
" <td>Asian-Pac-Islander</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&gt;=50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>42</td>\n",
" <td>Self-emp-not-inc</td>\n",
" <td>82297</td>\n",
" <td>7th-8th</td>\n",
" <td>NaN</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Other-service</td>\n",
" <td>Wife</td>\n",
" <td>Black</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>United-States</td>\n",
" <td>&lt;50k</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"0 49 Private 101320 Assoc-acdm 12.0 \n",
"1 44 Private 236746 Masters 14.0 \n",
"2 38 Private 96185 HS-grad NaN \n",
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
"\n",
" marital-status occupation relationship race \\\n",
"0 Married-civ-spouse NaN Wife White \n",
"1 Divorced Exec-managerial Not-in-family White \n",
"2 Divorced NaN Unmarried Black \n",
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
"4 Married-civ-spouse Other-service Wife Black \n",
"\n",
" sex capital-gain capital-loss hours-per-week native-country salary \n",
"0 Female 0 1902 40 United-States >=50k \n",
"1 Male 10520 0 45 United-States >=50k \n",
"2 Female 0 0 32 United-States <50k \n",
"3 Male 0 0 40 United-States >=50k \n",
"4 Female 0 0 50 United-States <50k "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#查看数据类型\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>fnlwgt</th>\n",
" <th>education-num</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>32561.000000</td>\n",
" <td>3.256100e+04</td>\n",
" <td>32074.000000</td>\n",
" <td>32561.000000</td>\n",
" <td>32561.000000</td>\n",
" <td>32561.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>38.581647</td>\n",
" <td>1.897784e+05</td>\n",
" <td>10.079815</td>\n",
" <td>1077.648844</td>\n",
" <td>87.303830</td>\n",
" <td>40.437456</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>13.640433</td>\n",
" <td>1.055500e+05</td>\n",
" <td>2.572999</td>\n",
" <td>7385.292085</td>\n",
" <td>402.960219</td>\n",
" <td>12.347429</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>17.000000</td>\n",
" <td>1.228500e+04</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>28.000000</td>\n",
" <td>1.178270e+05</td>\n",
" <td>9.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>40.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>37.000000</td>\n",
" <td>1.783560e+05</td>\n",
" <td>10.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>40.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>48.000000</td>\n",
" <td>2.370510e+05</td>\n",
" <td>12.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>45.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>90.000000</td>\n",
" <td>1.484705e+06</td>\n",
" <td>16.000000</td>\n",
" <td>99999.000000</td>\n",
" <td>4356.000000</td>\n",
" <td>99.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age fnlwgt education-num capital-gain capital-loss \\\n",
"count 32561.000000 3.256100e+04 32074.000000 32561.000000 32561.000000 \n",
"mean 38.581647 1.897784e+05 10.079815 1077.648844 87.303830 \n",
"std 13.640433 1.055500e+05 2.572999 7385.292085 402.960219 \n",
"min 17.000000 1.228500e+04 1.000000 0.000000 0.000000 \n",
"25% 28.000000 1.178270e+05 9.000000 0.000000 0.000000 \n",
"50% 37.000000 1.783560e+05 10.000000 0.000000 0.000000 \n",
"75% 48.000000 2.370510e+05 12.000000 0.000000 0.000000 \n",
"max 90.000000 1.484705e+06 16.000000 99999.000000 4356.000000 \n",
"\n",
" hours-per-week \n",
"count 32561.000000 \n",
"mean 40.437456 \n",
"std 12.347429 \n",
"min 1.000000 \n",
"25% 40.000000 \n",
"50% 40.000000 \n",
"75% 45.000000 \n",
"max 99.000000 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#pandas的describe可以告诉我们整个数据集的大概结构是非常有用的\n",
"df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32561"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#查看一共有多少数据\n",
"len(df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"对于模型的训练,只能够处理数字类型的数据,所以这里面我们首先要将数据分成三个类别\n",
"- 训练的结果标签:即训练的结果,通过这个结果我们就能够明确的知道我们这次训练的任务是什么,是分类的任务,还是回归的任务。\n",
"- 分类数据:这类的数据是离散的,无法通过直接输入到模型中进行训练,所以我们在预处理的时候需要优先对这部分进行处理,这也是数据预处理的主要工作之一\n",
"- 数值型数据:这类数据是直接可以输入到模型中的,但是这部分数据有可能还是离散的,所以如果需要也可以对其进行处理,并且处理后会对训练的精度有很大的提升,这里暂且不讨论"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#训练结果\n",
"result_var = 'salary'\n",
"#分类型数据\n",
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race','sex','native-country']\n",
"#数值型数据\n",
"cont_names = ['age', 'fnlwgt', 'education-num','capital-gain','capital-loss','hours-per-week']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"人工确认完数据类型后,我们可以看一下分类类型数据的数量和分布情况"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"workclass 9 Counter({' Private': 22696, ' Self-emp-not-inc': 2541, ' Local-gov': 2093, ' ?': 1836, ' State-gov': 1298, ' Self-emp-inc': 1116, ' Federal-gov': 960, ' Without-pay': 14, ' Never-worked': 7})\n",
"\r\n",
"\n",
"education 16 Counter({' HS-grad': 10501, ' Some-college': 7291, ' Bachelors': 5355, ' Masters': 1723, ' Assoc-voc': 1382, ' 11th': 1175, ' Assoc-acdm': 1067, ' 10th': 933, ' 7th-8th': 646, ' Prof-school': 576, ' 9th': 514, ' 12th': 433, ' Doctorate': 413, ' 5th-6th': 333, ' 1st-4th': 168, ' Preschool': 51})\n",
"\r\n",
"\n",
"marital-status 7 Counter({' Married-civ-spouse': 14976, ' Never-married': 10683, ' Divorced': 4443, ' Separated': 1025, ' Widowed': 993, ' Married-spouse-absent': 418, ' Married-AF-spouse': 23})\n",
"\r\n",
"\n",
"occupation 16 Counter({' Prof-specialty': 4073, ' Craft-repair': 4028, ' Exec-managerial': 4009, ' Adm-clerical': 3720, ' Sales': 3590, ' Other-service': 3247, ' Machine-op-inspct': 1968, ' ?': 1820, ' Transport-moving': 1566, ' Handlers-cleaners': 1347, ' Farming-fishing': 977, ' Tech-support': 905, ' Protective-serv': 643, nan: 512, ' Priv-house-serv': 147, ' Armed-Forces': 9})\n",
"\r\n",
"\n",
"relationship 6 Counter({' Husband': 13193, ' Not-in-family': 8305, ' Own-child': 5068, ' Unmarried': 3446, ' Wife': 1568, ' Other-relative': 981})\n",
"\r\n",
"\n",
"race 5 Counter({' White': 27816, ' Black': 3124, ' Asian-Pac-Islander': 1039, ' Amer-Indian-Eskimo': 311, ' Other': 271})\n",
"\r\n",
"\n",
"sex 2 Counter({' Male': 21790, ' Female': 10771})\n",
"\r\n",
"\n",
"native-country 42 Counter({' United-States': 29170, ' Mexico': 643, ' ?': 583, ' Philippines': 198, ' Germany': 137, ' Canada': 121, ' Puerto-Rico': 114, ' El-Salvador': 106, ' India': 100, ' Cuba': 95, ' England': 90, ' Jamaica': 81, ' South': 80, ' China': 75, ' Italy': 73, ' Dominican-Republic': 70, ' Vietnam': 67, ' Guatemala': 64, ' Japan': 62, ' Poland': 60, ' Columbia': 59, ' Taiwan': 51, ' Haiti': 44, ' Iran': 43, ' Portugal': 37, ' Nicaragua': 34, ' Peru': 31, ' Greece': 29, ' France': 29, ' Ecuador': 28, ' Ireland': 24, ' Hong': 20, ' Trinadad&Tobago': 19, ' Cambodia': 19, ' Thailand': 18, ' Laos': 18, ' Yugoslavia': 16, ' Outlying-US(Guam-USVI-etc)': 14, ' Hungary': 13, ' Honduras': 13, ' Scotland': 12, ' Holand-Netherlands': 1})\n",
"\r\n",
"\n"
]
}
],
"source": [
"for col in df.columns:\n",
" if col in cat_names:\n",
" ccol=Counter(df[col])\n",
" print(col,len(ccol),ccol)\n",
" print(\"\\r\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"下一步就是要将分类型数据转成数字型数据,在这部分里面,我们还做了对于缺失数据的填充"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"for col in df.columns:\n",
" if col in cat_names:\n",
" df[col].fillna('---')\n",
" df[col] = LabelEncoder().fit_transform(df[col].astype(str))\n",
" if col in cont_names:\n",
" df[col]=df[col].fillna(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"上面的代码中:\n",
"\n",
"我们首先使用了pandas的fillna函数对分类的数据做了空值的填充这里面标识成一个与其他现有值不一样的值就可以这里面我使用的三个中划线 --- 作为标记然后就是使用了sklearn的LabelEncoder函数进行了数据的处理\n",
"\n",
"然后有对我们的数值型的数据做了0填充的处理对于数值型数据的填充也可以使用平均值或者其他方式填充这个不是我们的重点就不详细说明了。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" <th>salary</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>49</td>\n",
" <td>4</td>\n",
" <td>101320</td>\n",
" <td>7</td>\n",
" <td>12.0</td>\n",
" <td>2</td>\n",
" <td>15</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1902</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" <td>&gt;=50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>44</td>\n",
" <td>4</td>\n",
" <td>236746</td>\n",
" <td>12</td>\n",
" <td>14.0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>10520</td>\n",
" <td>0</td>\n",
" <td>45</td>\n",
" <td>39</td>\n",
" <td>&gt;=50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>38</td>\n",
" <td>4</td>\n",
" <td>96185</td>\n",
" <td>11</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>15</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>32</td>\n",
" <td>39</td>\n",
" <td>&lt;50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>38</td>\n",
" <td>5</td>\n",
" <td>112847</td>\n",
" <td>14</td>\n",
" <td>15.0</td>\n",
" <td>2</td>\n",
" <td>10</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" <td>&gt;=50k</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>42</td>\n",
" <td>6</td>\n",
" <td>82297</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" <td>8</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>39</td>\n",
" <td>&lt;50k</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education-num marital-status \\\n",
"0 49 4 101320 7 12.0 2 \n",
"1 44 4 236746 12 14.0 0 \n",
"2 38 4 96185 11 0.0 0 \n",
"3 38 5 112847 14 15.0 2 \n",
"4 42 6 82297 5 0.0 2 \n",
"\n",
" occupation relationship race sex capital-gain capital-loss \\\n",
"0 15 5 4 0 0 1902 \n",
"1 4 1 4 1 10520 0 \n",
"2 15 4 2 0 0 0 \n",
"3 10 0 1 1 0 0 \n",
"4 8 5 2 0 0 0 \n",
"\n",
" hours-per-week native-country salary \n",
"0 40 39 >=50k \n",
"1 45 39 >=50k \n",
"2 32 39 <50k \n",
"3 40 39 >=50k \n",
"4 50 39 <50k "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"数据处理完成后可以看到,现在所有的数据都是数字类型的了,可以直接输入到模型进行训练了."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 1, 0, ..., 1, 0, 0])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#分割下训练数据和标签\n",
"Y = df['salary']\n",
"Y_label = LabelEncoder()\n",
"Y=Y_label.fit_transform(Y)\n",
"Y"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education-num</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" <th>native-country</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>49</td>\n",
" <td>4</td>\n",
" <td>101320</td>\n",
" <td>7</td>\n",
" <td>12.0</td>\n",
" <td>2</td>\n",
" <td>15</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1902</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>44</td>\n",
" <td>4</td>\n",
" <td>236746</td>\n",
" <td>12</td>\n",
" <td>14.0</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>10520</td>\n",
" <td>0</td>\n",
" <td>45</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>38</td>\n",
" <td>4</td>\n",
" <td>96185</td>\n",
" <td>11</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>15</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>32</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>38</td>\n",
" <td>5</td>\n",
" <td>112847</td>\n",
" <td>14</td>\n",
" <td>15.0</td>\n",
" <td>2</td>\n",
" <td>10</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>42</td>\n",
" <td>6</td>\n",
" <td>82297</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" <td>8</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32556</th>\n",
" <td>36</td>\n",
" <td>4</td>\n",
" <td>297449</td>\n",
" <td>9</td>\n",
" <td>13.0</td>\n",
" <td>0</td>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>14084</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32557</th>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" <td>123983</td>\n",
" <td>9</td>\n",
" <td>13.0</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32558</th>\n",
" <td>53</td>\n",
" <td>4</td>\n",
" <td>157069</td>\n",
" <td>7</td>\n",
" <td>12.0</td>\n",
" <td>2</td>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32559</th>\n",
" <td>32</td>\n",
" <td>2</td>\n",
" <td>217296</td>\n",
" <td>11</td>\n",
" <td>9.0</td>\n",
" <td>2</td>\n",
" <td>14</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>4064</td>\n",
" <td>0</td>\n",
" <td>22</td>\n",
" <td>39</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32560</th>\n",
" <td>26</td>\n",
" <td>4</td>\n",
" <td>182308</td>\n",
" <td>15</td>\n",
" <td>10.0</td>\n",
" <td>2</td>\n",
" <td>10</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>39</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>32561 rows × 14 columns</p>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education-num marital-status \\\n",
"0 49 4 101320 7 12.0 2 \n",
"1 44 4 236746 12 14.0 0 \n",
"2 38 4 96185 11 0.0 0 \n",
"3 38 5 112847 14 15.0 2 \n",
"4 42 6 82297 5 0.0 2 \n",
"... ... ... ... ... ... ... \n",
"32556 36 4 297449 9 13.0 0 \n",
"32557 23 0 123983 9 13.0 4 \n",
"32558 53 4 157069 7 12.0 2 \n",
"32559 32 2 217296 11 9.0 2 \n",
"32560 26 4 182308 15 10.0 2 \n",
"\n",
" occupation relationship race sex capital-gain capital-loss \\\n",
"0 15 5 4 0 0 1902 \n",
"1 4 1 4 1 10520 0 \n",
"2 15 4 2 0 0 0 \n",
"3 10 0 1 1 0 0 \n",
"4 8 5 2 0 0 0 \n",
"... ... ... ... ... ... ... \n",
"32556 10 1 4 1 14084 0 \n",
"32557 0 3 3 1 0 0 \n",
"32558 7 0 4 1 0 0 \n",
"32559 14 5 4 0 4064 0 \n",
"32560 10 0 4 1 0 0 \n",
"\n",
" hours-per-week native-country \n",
"0 40 39 \n",
"1 45 39 \n",
"2 32 39 \n",
"3 40 39 \n",
"4 50 39 \n",
"... ... ... \n",
"32556 40 39 \n",
"32557 40 39 \n",
"32558 40 39 \n",
"32559 22 39 \n",
"32560 40 39 \n",
"\n",
"[32561 rows x 14 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X=df.drop(columns=result_var)\n",
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上,基本的数据预处理已经完成了,上面展示的只是一些必要的处理,如果要提高训练准确率还有很多技巧,这里就不详细说明了。\n",
"## 定义数据集\n",
"要使用pytorch处理数据肯定要使用Dataset进行数据集的定义下面定义一个简单的数据集"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class tabularDataset(Dataset):\n",
" def __init__(self, X, Y):\n",
" self.x = X.values\n",
" self.y = Y\n",
" \n",
" def __len__(self):\n",
" return len(self.y)\n",
" \n",
" def __getitem__(self, idx):\n",
" return (self.x[idx], self.y[idx])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train_ds = tabularDataset(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"可以直接使用索引访问定义好的数据集中的数据"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([4.9000e+01, 4.0000e+00, 1.0132e+05, 7.0000e+00, 1.2000e+01,\n",
" 2.0000e+00, 1.5000e+01, 5.0000e+00, 4.0000e+00, 0.0000e+00,\n",
" 0.0000e+00, 1.9020e+03, 4.0000e+01, 3.9000e+01]),\n",
" 1)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ds[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 定义模型\n",
"数据已经准备完毕了下一步就是要定义我们的模型了这里使用了3层线性层的简单模型作为处理"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class tabularModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin1 = nn.Linear(14, 500)\n",
" self.lin2 = nn.Linear(500, 100)\n",
" self.lin3 = nn.Linear(100, 2)\n",
" self.bn_in = nn.BatchNorm1d(14)\n",
" self.bn1 = nn.BatchNorm1d(500)\n",
" self.bn2 = nn.BatchNorm1d(100)\n",
" \n",
"\n",
" def forward(self,x_in):\n",
" #print(x_in.shape)\n",
" x = self.bn_in(x_in)\n",
" x = F.relu(self.lin1(x))\n",
" x = self.bn1(x)\n",
" #print(x)\n",
" \n",
" \n",
" x = F.relu(self.lin2(x))\n",
" x = self.bn2(x)\n",
" #print(x)\n",
" \n",
" x = self.lin3(x)\n",
" x=torch.sigmoid(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在定义模型的时候看到了我们加入了Batch Normalization来做批量的归一化\n",
"批量归一化的内容请见这篇文章https://mp.weixin.qq.com/s/FFLQBocTZGqnyN79JbSYcQ\n",
"\n",
"或者扫描这个二维码,在微信中查看:\n",
"![](https://raw.githubusercontent.com/zergtant/pytorch-handbook/master/deephub.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 训练"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"#训练前指定使用的设备\n",
"DEVICE=torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
" DEVICE=torch.device(\"cuda\")\n",
"print(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"#损失函数\n",
"criterion =nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tabularModel(\n",
" (lin1): Linear(in_features=14, out_features=500, bias=True)\n",
" (lin2): Linear(in_features=500, out_features=100, bias=True)\n",
" (lin3): Linear(in_features=100, out_features=2, bias=True)\n",
" (bn_in): BatchNorm1d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (bn1): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
")\n"
]
}
],
"source": [
"#实例化模型\n",
"model = tabularModel().to(DEVICE)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.5110, 0.1931],\n",
" [0.4274, 0.5801],\n",
" [0.5549, 0.7322]], device='cuda:0', grad_fn=<SigmoidBackward>)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#测试模型是否没问题\n",
"rn=torch.rand(3,14).to(DEVICE)\n",
"model(rn)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"#学习率\n",
"LEARNING_RATE=0.01\n",
"#BS\n",
"batch_size = 1024\n",
"#优化器\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"#DataLoader加载数据\n",
"train_dl = DataLoader(train_ds, batch_size=batch_size,shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上的基本步骤是每个训练过程都需要的,所以就不多介绍了,下面开始进行模型的训练"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch : 1/100, Loss: 0.4936\n",
"Epoch : 2/100, Loss: 0.4766\n",
"Epoch : 3/100, Loss: 0.4693\n",
"Epoch : 4/100, Loss: 0.4653\n",
"Epoch : 5/100, Loss: 0.4627\n",
"Epoch : 6/100, Loss: 0.4606\n",
"Epoch : 7/100, Loss: 0.4591\n",
"Epoch : 8/100, Loss: 0.4582\n",
"Epoch : 9/100, Loss: 0.4573\n",
"Epoch : 10/100, Loss: 0.4565\n",
"Epoch : 11/100, Loss: 0.4557\n",
"Epoch : 12/100, Loss: 0.4551\n",
"Epoch : 13/100, Loss: 0.4546\n",
"Epoch : 14/100, Loss: 0.4540\n",
"Epoch : 15/100, Loss: 0.4535\n",
"Epoch : 16/100, Loss: 0.4530\n",
"Epoch : 17/100, Loss: 0.4526\n",
"Epoch : 18/100, Loss: 0.4522\n",
"Epoch : 19/100, Loss: 0.4519\n",
"Epoch : 20/100, Loss: 0.4515\n",
"Epoch : 21/100, Loss: 0.4511\n",
"Epoch : 22/100, Loss: 0.4508\n",
"Epoch : 23/100, Loss: 0.4504\n",
"Epoch : 24/100, Loss: 0.4502\n",
"Epoch : 25/100, Loss: 0.4499\n",
"Epoch : 26/100, Loss: 0.4496\n",
"Epoch : 27/100, Loss: 0.4492\n",
"Epoch : 28/100, Loss: 0.4489\n",
"Epoch : 29/100, Loss: 0.4486\n",
"Epoch : 30/100, Loss: 0.4483\n",
"Epoch : 31/100, Loss: 0.4480\n",
"Epoch : 32/100, Loss: 0.4477\n",
"Epoch : 33/100, Loss: 0.4474\n",
"Epoch : 34/100, Loss: 0.4471\n",
"Epoch : 35/100, Loss: 0.4469\n",
"Epoch : 36/100, Loss: 0.4466\n",
"Epoch : 37/100, Loss: 0.4463\n",
"Epoch : 38/100, Loss: 0.4460\n",
"Epoch : 39/100, Loss: 0.4458\n",
"Epoch : 40/100, Loss: 0.4455\n",
"Epoch : 41/100, Loss: 0.4452\n",
"Epoch : 42/100, Loss: 0.4449\n",
"Epoch : 43/100, Loss: 0.4447\n",
"Epoch : 44/100, Loss: 0.4445\n",
"Epoch : 45/100, Loss: 0.4442\n",
"Epoch : 46/100, Loss: 0.4439\n",
"Epoch : 47/100, Loss: 0.4437\n",
"Epoch : 48/100, Loss: 0.4434\n",
"Epoch : 49/100, Loss: 0.4432\n",
"Epoch : 50/100, Loss: 0.4429\n",
"Epoch : 51/100, Loss: 0.4426\n",
"Epoch : 52/100, Loss: 0.4424\n",
"Epoch : 53/100, Loss: 0.4421\n",
"Epoch : 54/100, Loss: 0.4419\n",
"Epoch : 55/100, Loss: 0.4417\n",
"Epoch : 56/100, Loss: 0.4414\n",
"Epoch : 57/100, Loss: 0.4411\n",
"Epoch : 58/100, Loss: 0.4409\n",
"Epoch : 59/100, Loss: 0.4406\n",
"Epoch : 60/100, Loss: 0.4404\n",
"Epoch : 61/100, Loss: 0.4402\n",
"Epoch : 62/100, Loss: 0.4399\n",
"Epoch : 63/100, Loss: 0.4397\n",
"Epoch : 64/100, Loss: 0.4394\n",
"Epoch : 65/100, Loss: 0.4392\n",
"Epoch : 66/100, Loss: 0.4390\n",
"Epoch : 67/100, Loss: 0.4387\n",
"Epoch : 68/100, Loss: 0.4384\n",
"Epoch : 69/100, Loss: 0.4382\n",
"Epoch : 70/100, Loss: 0.4380\n",
"Epoch : 71/100, Loss: 0.4377\n",
"Epoch : 72/100, Loss: 0.4375\n",
"Epoch : 73/100, Loss: 0.4373\n",
"Epoch : 74/100, Loss: 0.4371\n",
"Epoch : 75/100, Loss: 0.4368\n",
"Epoch : 76/100, Loss: 0.4366\n",
"Epoch : 77/100, Loss: 0.4364\n",
"Epoch : 78/100, Loss: 0.4362\n",
"Epoch : 79/100, Loss: 0.4360\n",
"Epoch : 80/100, Loss: 0.4358\n",
"Epoch : 81/100, Loss: 0.4356\n",
"Epoch : 82/100, Loss: 0.4353\n",
"Epoch : 83/100, Loss: 0.4351\n",
"Epoch : 84/100, Loss: 0.4348\n",
"Epoch : 85/100, Loss: 0.4346\n",
"Epoch : 86/100, Loss: 0.4344\n",
"Epoch : 87/100, Loss: 0.4342\n",
"Epoch : 88/100, Loss: 0.4340\n",
"Epoch : 89/100, Loss: 0.4338\n",
"Epoch : 90/100, Loss: 0.4336\n",
"Epoch : 91/100, Loss: 0.4333\n",
"Epoch : 92/100, Loss: 0.4331\n",
"Epoch : 93/100, Loss: 0.4329\n",
"Epoch : 94/100, Loss: 0.4328\n",
"Epoch : 95/100, Loss: 0.4326\n",
"Epoch : 96/100, Loss: 0.4324\n",
"Epoch : 97/100, Loss: 0.4322\n",
"Epoch : 98/100, Loss: 0.4320\n",
"Epoch : 99/100, Loss: 0.4318\n",
"Epoch : 100/100, Loss: 0.4316\n",
"Wall time: 49.6 s\n"
]
}
],
"source": [
"%%time\n",
"model.train()\n",
"#训练10轮\n",
"TOTAL_EPOCHS=100\n",
"#记录损失函数\n",
"losses = [];\n",
"for epoch in range(TOTAL_EPOCHS):\n",
" for i, (x, y) in enumerate(train_dl):\n",
" x = x.float().to(DEVICE) #输入必须未float类型\n",
" y = y.long().to(DEVICE) #结果标签必须未long类型\n",
" #清零\n",
" optimizer.zero_grad()\n",
" outputs = model(x)\n",
" #计算损失函数\n",
" loss = criterion(outputs, y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" losses.append(loss.cpu().data.item())\n",
" print ('Epoch : %d/%d, Loss: %.4f'%(epoch+1, TOTAL_EPOCHS, np.mean(losses)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"训练完成后我们可以看一下模型的准确率"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"准确率: 89.0000 %\n"
]
}
],
"source": [
"model.eval()\n",
"correct = 0\n",
"total = 0\n",
"for i,(x, y) in enumerate(train_dl):\n",
" x = x.float().to(DEVICE)\n",
" y = y.long()\n",
" outputs = model(x).cpu()\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += y.size(0)\n",
" correct += (predicted == y).sum()\n",
"print('准确率: %.4f %%' % (100 * correct / total))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以上就是基本的流程了"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deeplearning",
"language": "python",
"name": "dl"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "307.2px"
},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}