{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'1.3.0'"
|
||
]
|
||
},
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import torch\n",
|
||
"from sklearn.preprocessing import LabelEncoder\n",
|
||
"from torch.utils.data import Dataset, DataLoader\n",
|
||
"import torch.nn.functional as F\n",
|
||
"import torch.nn as nn\n",
|
||
"from collections import Counter\n",
|
||
"torch.__version__"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
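"# 5.2 Handling Structured Data with PyTorch\n",
"## Introduction\n",
"Before we begin, let us clarify what structured data is. As the name suggests, structured data is highly organized and neatly formatted: it is the kind of data that fits into tables and spreadsheets. For our purposes, structured data is simply a two-dimensional table, and a CSV file is a typical example. In English it is usually called tabular data or structured data. Let's look at an example.\n",
"\n",
"The file used below comes from a dataset bundled with fastai. The fastai example notebook is here:\n",
"https://github.com/fastai/fastai/blob/master/examples/tabular.ipynb\n"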
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## Data preprocessing\n",
"Structured data usually arrives as a CSV file or a database table, so we can process it directly with the pandas library."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array(['>=50k', '<50k'], dtype=object)"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"# Read the file\n",
"df = pd.read_csv('./data/adult.csv')\n",
"# salary is the target column that we want to classify\n",
"df['salary'].unique()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>workclass</th>\n",
|
||
" <th>fnlwgt</th>\n",
|
||
" <th>education</th>\n",
|
||
" <th>education-num</th>\n",
|
||
" <th>marital-status</th>\n",
|
||
" <th>occupation</th>\n",
|
||
" <th>relationship</th>\n",
|
||
" <th>race</th>\n",
|
||
" <th>sex</th>\n",
|
||
" <th>capital-gain</th>\n",
|
||
" <th>capital-loss</th>\n",
|
||
" <th>hours-per-week</th>\n",
|
||
" <th>native-country</th>\n",
|
||
" <th>salary</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>0</td>\n",
|
||
" <td>49</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>101320</td>\n",
|
||
" <td>Assoc-acdm</td>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>Married-civ-spouse</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>Wife</td>\n",
|
||
" <td>White</td>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1902</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>United-States</td>\n",
|
||
" <td>>=50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>1</td>\n",
|
||
" <td>44</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>236746</td>\n",
|
||
" <td>Masters</td>\n",
|
||
" <td>14.0</td>\n",
|
||
" <td>Divorced</td>\n",
|
||
" <td>Exec-managerial</td>\n",
|
||
" <td>Not-in-family</td>\n",
|
||
" <td>White</td>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>10520</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>45</td>\n",
|
||
" <td>United-States</td>\n",
|
||
" <td>>=50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>2</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>96185</td>\n",
|
||
" <td>HS-grad</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>Divorced</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>Unmarried</td>\n",
|
||
" <td>Black</td>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>United-States</td>\n",
|
||
" <td><50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>3</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>Self-emp-inc</td>\n",
|
||
" <td>112847</td>\n",
|
||
" <td>Prof-school</td>\n",
|
||
" <td>15.0</td>\n",
|
||
" <td>Married-civ-spouse</td>\n",
|
||
" <td>Prof-specialty</td>\n",
|
||
" <td>Husband</td>\n",
|
||
" <td>Asian-Pac-Islander</td>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>United-States</td>\n",
|
||
" <td>>=50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>4</td>\n",
|
||
" <td>42</td>\n",
|
||
" <td>Self-emp-not-inc</td>\n",
|
||
" <td>82297</td>\n",
|
||
" <td>7th-8th</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>Married-civ-spouse</td>\n",
|
||
" <td>Other-service</td>\n",
|
||
" <td>Wife</td>\n",
|
||
" <td>Black</td>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>United-States</td>\n",
|
||
" <td><50k</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age workclass fnlwgt education education-num \\\n",
|
||
"0 49 Private 101320 Assoc-acdm 12.0 \n",
|
||
"1 44 Private 236746 Masters 14.0 \n",
|
||
"2 38 Private 96185 HS-grad NaN \n",
|
||
"3 38 Self-emp-inc 112847 Prof-school 15.0 \n",
|
||
"4 42 Self-emp-not-inc 82297 7th-8th NaN \n",
|
||
"\n",
|
||
" marital-status occupation relationship race \\\n",
|
||
"0 Married-civ-spouse NaN Wife White \n",
|
||
"1 Divorced Exec-managerial Not-in-family White \n",
|
||
"2 Divorced NaN Unmarried Black \n",
|
||
"3 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander \n",
|
||
"4 Married-civ-spouse Other-service Wife Black \n",
|
||
"\n",
|
||
" sex capital-gain capital-loss hours-per-week native-country salary \n",
|
||
"0 Female 0 1902 40 United-States >=50k \n",
|
||
"1 Male 10520 0 45 United-States >=50k \n",
|
||
"2 Female 0 0 32 United-States <50k \n",
|
||
"3 Male 0 0 40 United-States >=50k \n",
|
||
"4 Female 0 0 50 United-States <50k "
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"# Take a look at the data (first five rows)\n",
"df.head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>fnlwgt</th>\n",
|
||
" <th>education-num</th>\n",
|
||
" <th>capital-gain</th>\n",
|
||
" <th>capital-loss</th>\n",
|
||
" <th>hours-per-week</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>count</td>\n",
|
||
" <td>32561.000000</td>\n",
|
||
" <td>3.256100e+04</td>\n",
|
||
" <td>32074.000000</td>\n",
|
||
" <td>32561.000000</td>\n",
|
||
" <td>32561.000000</td>\n",
|
||
" <td>32561.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>mean</td>\n",
|
||
" <td>38.581647</td>\n",
|
||
" <td>1.897784e+05</td>\n",
|
||
" <td>10.079815</td>\n",
|
||
" <td>1077.648844</td>\n",
|
||
" <td>87.303830</td>\n",
|
||
" <td>40.437456</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>std</td>\n",
|
||
" <td>13.640433</td>\n",
|
||
" <td>1.055500e+05</td>\n",
|
||
" <td>2.572999</td>\n",
|
||
" <td>7385.292085</td>\n",
|
||
" <td>402.960219</td>\n",
|
||
" <td>12.347429</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>min</td>\n",
|
||
" <td>17.000000</td>\n",
|
||
" <td>1.228500e+04</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>25%</td>\n",
|
||
" <td>28.000000</td>\n",
|
||
" <td>1.178270e+05</td>\n",
|
||
" <td>9.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>40.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>50%</td>\n",
|
||
" <td>37.000000</td>\n",
|
||
" <td>1.783560e+05</td>\n",
|
||
" <td>10.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>40.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>75%</td>\n",
|
||
" <td>48.000000</td>\n",
|
||
" <td>2.370510e+05</td>\n",
|
||
" <td>12.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>0.000000</td>\n",
|
||
" <td>45.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>max</td>\n",
|
||
" <td>90.000000</td>\n",
|
||
" <td>1.484705e+06</td>\n",
|
||
" <td>16.000000</td>\n",
|
||
" <td>99999.000000</td>\n",
|
||
" <td>4356.000000</td>\n",
|
||
" <td>99.000000</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age fnlwgt education-num capital-gain capital-loss \\\n",
|
||
"count 32561.000000 3.256100e+04 32074.000000 32561.000000 32561.000000 \n",
|
||
"mean 38.581647 1.897784e+05 10.079815 1077.648844 87.303830 \n",
|
||
"std 13.640433 1.055500e+05 2.572999 7385.292085 402.960219 \n",
|
||
"min 17.000000 1.228500e+04 1.000000 0.000000 0.000000 \n",
|
||
"25% 28.000000 1.178270e+05 9.000000 0.000000 0.000000 \n",
|
||
"50% 37.000000 1.783560e+05 10.000000 0.000000 0.000000 \n",
|
||
"75% 48.000000 2.370510e+05 12.000000 0.000000 0.000000 \n",
|
||
"max 90.000000 1.484705e+06 16.000000 99999.000000 4356.000000 \n",
|
||
"\n",
|
||
" hours-per-week \n",
|
||
"count 32561.000000 \n",
|
||
"mean 40.437456 \n",
|
||
"std 12.347429 \n",
|
||
"min 1.000000 \n",
|
||
"25% 40.000000 \n",
|
||
"50% 40.000000 \n",
|
||
"75% 45.000000 \n",
|
||
"max 99.000000 "
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"# pandas' describe() summarizes the numeric columns of the dataset, which is very useful\n",
"df.describe()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"32561"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"# Check how many rows there are in total\n",
"len(df)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
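"A model can only be trained on numeric data, so we first split the columns into three groups:\n",
"- Target label: the outcome to predict. It tells us what the task is, that is, classification or regression.\n",
"- Categorical data: discrete values that cannot be fed to the model directly, so they must be handled first during preprocessing. This is one of the main jobs of data preprocessing.\n",
"- Numeric data: values that can be fed to the model directly. Some of these columns may still be discrete, so they can also be processed if needed, which can noticeably improve accuracy. We do not discuss that here."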
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"# Target column\n",
"result_var = 'salary'\n",
"# Categorical columns\n",
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race','sex','native-country']\n",
"# Numeric (continuous) columns\n",
"cont_names = ['age', 'fnlwgt', 'education-num','capital-gain','capital-loss','hours-per-week']"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"With the column types confirmed by hand, we can look at the number of distinct values and the distribution of each categorical column."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"workclass 9 Counter({' Private': 22696, ' Self-emp-not-inc': 2541, ' Local-gov': 2093, ' ?': 1836, ' State-gov': 1298, ' Self-emp-inc': 1116, ' Federal-gov': 960, ' Without-pay': 14, ' Never-worked': 7})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"education 16 Counter({' HS-grad': 10501, ' Some-college': 7291, ' Bachelors': 5355, ' Masters': 1723, ' Assoc-voc': 1382, ' 11th': 1175, ' Assoc-acdm': 1067, ' 10th': 933, ' 7th-8th': 646, ' Prof-school': 576, ' 9th': 514, ' 12th': 433, ' Doctorate': 413, ' 5th-6th': 333, ' 1st-4th': 168, ' Preschool': 51})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"marital-status 7 Counter({' Married-civ-spouse': 14976, ' Never-married': 10683, ' Divorced': 4443, ' Separated': 1025, ' Widowed': 993, ' Married-spouse-absent': 418, ' Married-AF-spouse': 23})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"occupation 16 Counter({' Prof-specialty': 4073, ' Craft-repair': 4028, ' Exec-managerial': 4009, ' Adm-clerical': 3720, ' Sales': 3590, ' Other-service': 3247, ' Machine-op-inspct': 1968, ' ?': 1820, ' Transport-moving': 1566, ' Handlers-cleaners': 1347, ' Farming-fishing': 977, ' Tech-support': 905, ' Protective-serv': 643, nan: 512, ' Priv-house-serv': 147, ' Armed-Forces': 9})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"relationship 6 Counter({' Husband': 13193, ' Not-in-family': 8305, ' Own-child': 5068, ' Unmarried': 3446, ' Wife': 1568, ' Other-relative': 981})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"race 5 Counter({' White': 27816, ' Black': 3124, ' Asian-Pac-Islander': 1039, ' Amer-Indian-Eskimo': 311, ' Other': 271})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"sex 2 Counter({' Male': 21790, ' Female': 10771})\n",
|
||
"\r\n",
|
||
"\n",
|
||
"native-country 42 Counter({' United-States': 29170, ' Mexico': 643, ' ?': 583, ' Philippines': 198, ' Germany': 137, ' Canada': 121, ' Puerto-Rico': 114, ' El-Salvador': 106, ' India': 100, ' Cuba': 95, ' England': 90, ' Jamaica': 81, ' South': 80, ' China': 75, ' Italy': 73, ' Dominican-Republic': 70, ' Vietnam': 67, ' Guatemala': 64, ' Japan': 62, ' Poland': 60, ' Columbia': 59, ' Taiwan': 51, ' Haiti': 44, ' Iran': 43, ' Portugal': 37, ' Nicaragua': 34, ' Peru': 31, ' Greece': 29, ' France': 29, ' Ecuador': 28, ' Ireland': 24, ' Hong': 20, ' Trinadad&Tobago': 19, ' Cambodia': 19, ' Thailand': 18, ' Laos': 18, ' Yugoslavia': 16, ' Outlying-US(Guam-USVI-etc)': 14, ' Hungary': 13, ' Honduras': 13, ' Scotland': 12, ' Holand-Netherlands': 1})\n",
|
||
"\r\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"for col in df.columns:\n",
"    if col in cat_names:\n",
"        ccol=Counter(df[col])\n",
"        print(col,len(ccol),ccol)\n",
"        print(\"\\r\\n\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"The next step is to convert the categorical data to numbers. In this step we also fill in missing values."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"for col in df.columns:\n",
"    if col in cat_names:\n",
"        # fillna returns a new Series, so assign the result back\n",
"        df[col] = df[col].fillna('---')\n",
"        df[col] = LabelEncoder().fit_transform(df[col].astype(str))\n",
"    if col in cont_names:\n",
"        df[col]=df[col].fillna(0)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
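"In the code above:\n",
"\n",
"We first use pandas' fillna to fill the missing values in the categorical columns. Any marker that differs from the existing values will do; here I use three hyphens, ---. We then encode each column with sklearn's LabelEncoder, which maps every distinct value to an integer.\n",
"\n",
"We then fill the missing numeric values with 0. Numeric columns could also be filled with the mean or by other methods; that is not our focus, so we do not go into detail."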
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>workclass</th>\n",
|
||
" <th>fnlwgt</th>\n",
|
||
" <th>education</th>\n",
|
||
" <th>education-num</th>\n",
|
||
" <th>marital-status</th>\n",
|
||
" <th>occupation</th>\n",
|
||
" <th>relationship</th>\n",
|
||
" <th>race</th>\n",
|
||
" <th>sex</th>\n",
|
||
" <th>capital-gain</th>\n",
|
||
" <th>capital-loss</th>\n",
|
||
" <th>hours-per-week</th>\n",
|
||
" <th>native-country</th>\n",
|
||
" <th>salary</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>0</td>\n",
|
||
" <td>49</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>101320</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>15</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1902</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>>=50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>1</td>\n",
|
||
" <td>44</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>236746</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>14.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>10520</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>45</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>>=50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>2</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>96185</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>15</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td><50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>3</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>112847</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>15.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>>=50k</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>4</td>\n",
|
||
" <td>42</td>\n",
|
||
" <td>6</td>\n",
|
||
" <td>82297</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>8</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td><50k</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age workclass fnlwgt education education-num marital-status \\\n",
|
||
"0 49 4 101320 7 12.0 2 \n",
|
||
"1 44 4 236746 12 14.0 0 \n",
|
||
"2 38 4 96185 11 0.0 0 \n",
|
||
"3 38 5 112847 14 15.0 2 \n",
|
||
"4 42 6 82297 5 0.0 2 \n",
|
||
"\n",
|
||
" occupation relationship race sex capital-gain capital-loss \\\n",
|
||
"0 15 5 4 0 0 1902 \n",
|
||
"1 4 1 4 1 10520 0 \n",
|
||
"2 15 4 2 0 0 0 \n",
|
||
"3 10 0 1 1 0 0 \n",
|
||
"4 8 5 2 0 0 0 \n",
|
||
"\n",
|
||
" hours-per-week native-country salary \n",
|
||
"0 40 39 >=50k \n",
|
||
"1 45 39 >=50k \n",
|
||
"2 32 39 <50k \n",
|
||
"3 40 39 >=50k \n",
|
||
"4 50 39 <50k "
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"df.head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"After preprocessing, all feature columns are numeric and can be fed directly to the model for training."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array([1, 1, 0, ..., 1, 0, 0])"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"# Separate the labels from the training data\n",
"Y = df['salary']\n",
"Y_label = LabelEncoder()\n",
"Y=Y_label.fit_transform(Y)\n",
"Y"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>workclass</th>\n",
|
||
" <th>fnlwgt</th>\n",
|
||
" <th>education</th>\n",
|
||
" <th>education-num</th>\n",
|
||
" <th>marital-status</th>\n",
|
||
" <th>occupation</th>\n",
|
||
" <th>relationship</th>\n",
|
||
" <th>race</th>\n",
|
||
" <th>sex</th>\n",
|
||
" <th>capital-gain</th>\n",
|
||
" <th>capital-loss</th>\n",
|
||
" <th>hours-per-week</th>\n",
|
||
" <th>native-country</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <td>0</td>\n",
|
||
" <td>49</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>101320</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>15</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1902</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>1</td>\n",
|
||
" <td>44</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>236746</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>14.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>10520</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>45</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>2</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>96185</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>15</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>3</td>\n",
|
||
" <td>38</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>112847</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>15.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>4</td>\n",
|
||
" <td>42</td>\n",
|
||
" <td>6</td>\n",
|
||
" <td>82297</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>8</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>32556</td>\n",
|
||
" <td>36</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>297449</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>14084</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>32557</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>123983</td>\n",
|
||
" <td>9</td>\n",
|
||
" <td>13.0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>32558</td>\n",
|
||
" <td>53</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>157069</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>7</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>32559</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>217296</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>9.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>5</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>4064</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <td>32560</td>\n",
|
||
" <td>26</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>182308</td>\n",
|
||
" <td>15</td>\n",
|
||
" <td>10.0</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>10</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>39</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>32561 rows × 14 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age workclass fnlwgt education education-num marital-status \\\n",
|
||
"0 49 4 101320 7 12.0 2 \n",
|
||
"1 44 4 236746 12 14.0 0 \n",
|
||
"2 38 4 96185 11 0.0 0 \n",
|
||
"3 38 5 112847 14 15.0 2 \n",
|
||
"4 42 6 82297 5 0.0 2 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"32556 36 4 297449 9 13.0 0 \n",
|
||
"32557 23 0 123983 9 13.0 4 \n",
|
||
"32558 53 4 157069 7 12.0 2 \n",
|
||
"32559 32 2 217296 11 9.0 2 \n",
|
||
"32560 26 4 182308 15 10.0 2 \n",
|
||
"\n",
|
||
" occupation relationship race sex capital-gain capital-loss \\\n",
|
||
"0 15 5 4 0 0 1902 \n",
|
||
"1 4 1 4 1 10520 0 \n",
|
||
"2 15 4 2 0 0 0 \n",
|
||
"3 10 0 1 1 0 0 \n",
|
||
"4 8 5 2 0 0 0 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"32556 10 1 4 1 14084 0 \n",
|
||
"32557 0 3 3 1 0 0 \n",
|
||
"32558 7 0 4 1 0 0 \n",
|
||
"32559 14 5 4 0 4064 0 \n",
|
||
"32560 10 0 4 1 0 0 \n",
|
||
"\n",
|
||
" hours-per-week native-country \n",
|
||
"0 40 39 \n",
|
||
"1 45 39 \n",
|
||
"2 32 39 \n",
|
||
"3 40 39 \n",
|
||
"4 50 39 \n",
|
||
"... ... ... \n",
|
||
"32556 40 39 \n",
|
||
"32557 40 39 \n",
|
||
"32558 40 39 \n",
|
||
"32559 22 39 \n",
|
||
"32560 40 39 \n",
|
||
"\n",
|
||
"[32561 rows x 14 columns]"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"X=df.drop(columns=result_var)\n",
|
||
"X"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
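"That completes the basic data preprocessing. Only the essential steps are shown above; there are many more techniques for improving accuracy, which we do not cover here.\n",
"## Defining the dataset\n",
"To feed data to PyTorch, the usual approach is to define a Dataset. Below is a simple one."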
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"class tabularDataset(Dataset):\n",
"    def __init__(self, X, Y):\n",
"        self.x = X#.to_numpy().astype(float)\n",
"        self.y = Y\n",
"    \n",
"    def __len__(self):\n",
"        return len(self.y)\n",
"    \n",
"    def __getitem__(self, idx):\n",
"        return (self.x.values[idx], self.y[idx])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"train_ds = tabularDataset(X, Y)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"Items in the dataset can be accessed directly by index."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(array([4.9000e+01, 4.0000e+00, 1.0132e+05, 7.0000e+00, 1.2000e+01,\n",
|
||
" 2.0000e+00, 1.5000e+01, 5.0000e+00, 4.0000e+00, 0.0000e+00,\n",
|
||
" 0.0000e+00, 1.9020e+03, 4.0000e+01, 3.9000e+01]), 1)"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"train_ds[0]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## Defining the model\n",
"With the data ready, the next step is to define our model. Here we use a simple model with three linear layers."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"class tabularModel(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.lin1 = nn.Linear(14, 500)\n",
"        self.lin2 = nn.Linear(500, 100)\n",
"        self.lin3 = nn.Linear(100, 2)\n",
"        self.bn1 = nn.BatchNorm1d(14)\n",
"        self.bn2 = nn.BatchNorm1d(500)\n",
"        self.bn3 = nn.BatchNorm1d(100)\n",
"        \n",
"\n",
"    def forward(self,x_in):\n",
"        #print(x_in.shape)\n",
"        x=x_in\n",
"        x = self.bn1(x)\n",
"        x = F.relu(self.lin1(x))\n",
"        #print(x)\n",
"        \n",
"        x = self.bn2(x)\n",
"        x = F.relu(self.lin2(x))\n",
"        #print(x)\n",
"        \n",
"        x = self.bn3(x)\n",
"        x = self.lin3(x)\n",
"        # Note: nn.CrossEntropyLoss applies log-softmax internally and is normally given raw logits;\n",
"        # the sigmoid is kept here so that the recorded outputs below still match.\n",
"        x=torch.sigmoid(x)\n",
"        return x"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"Note that the model definition includes Batch Normalization layers, which normalize each batch.\n",
"For more on batch normalization, see this article: https://mp.weixin.qq.com/s/FFLQBocTZGqnyN79JbSYcQ"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## Training"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"cuda\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"# Choose the device before training\n",
"DEVICE=torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
"    DEVICE=torch.device(\"cuda\")\n",
"print(DEVICE)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"# Loss function\n",
"criterion =nn.CrossEntropyLoss().to(DEVICE)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"tabularModel(\n",
|
||
" (lin1): Linear(in_features=14, out_features=500, bias=True)\n",
|
||
" (lin2): Linear(in_features=500, out_features=100, bias=True)\n",
|
||
" (lin3): Linear(in_features=100, out_features=2, bias=True)\n",
|
||
" (bn1): BatchNorm1d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (bn2): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (bn3): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
")\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
"# Instantiate the model\n",
"model = tabularModel().to(DEVICE)\n",
"print(model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[0.4717, 0.4703],\n",
|
||
" [0.6462, 0.3824],\n",
|
||
" [0.3931, 0.6696]], device='cuda:0', grad_fn=<SigmoidBackward>)"
|
||
]
|
||
},
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"# Quick sanity check that the model runs\n",
"rn=torch.rand(3,14).to(DEVICE)\n",
"model(rn)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"# Learning rate\n",
"LEARNING_RATE=0.01\n",
"# Batch size\n",
"batch_size = 2048\n",
"# Optimizer\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Load the training data with a DataLoader\n",
"train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
"The steps above are needed in every training workflow, so we will not go into further detail. Next, we train the model."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch : 1/10, Loss: 0.7006\n",
|
||
"Epoch : 1/10, Loss: 0.6112\n",
|
||
"Epoch : 1/10, Loss: 0.5917\n",
|
||
"Epoch : 1/10, Loss: 0.5557\n",
|
||
"Epoch : 1/10, Loss: 0.5422\n",
|
||
"Epoch : 1/10, Loss: 0.5292\n",
|
||
"Epoch : 1/10, Loss: 0.5104\n",
|
||
"Epoch : 1/10, Loss: 0.5031\n",
|
||
"Epoch : 1/10, Loss: 0.5068\n",
|
||
"Epoch : 1/10, Loss: 0.4939\n",
|
||
"Epoch : 1/10, Loss: 0.4782\n",
|
||
"Epoch : 1/10, Loss: 0.4864\n",
|
||
"Epoch : 1/10, Loss: 0.4845\n",
|
||
"Epoch : 1/10, Loss: 0.4787\n",
|
||
"Epoch : 1/10, Loss: 0.4746\n",
|
||
"Epoch : 1/10, Loss: 0.4671\n",
|
||
"Epoch : 2/10, Loss: 0.4517\n",
|
||
"Epoch : 2/10, Loss: 0.4780\n",
|
||
"Epoch : 2/10, Loss: 0.4703\n",
|
||
"Epoch : 2/10, Loss: 0.4614\n",
|
||
"Epoch : 2/10, Loss: 0.4655\n",
|
||
"Epoch : 2/10, Loss: 0.4640\n",
|
||
"Epoch : 2/10, Loss: 0.4520\n",
|
||
"Epoch : 2/10, Loss: 0.4710\n",
|
||
"Epoch : 2/10, Loss: 0.4644\n",
|
||
"Epoch : 2/10, Loss: 0.4633\n",
|
||
"Epoch : 2/10, Loss: 0.4533\n",
|
||
"Epoch : 2/10, Loss: 0.4676\n",
|
||
"Epoch : 2/10, Loss: 0.4724\n",
|
||
"Epoch : 2/10, Loss: 0.4544\n",
|
||
"Epoch : 2/10, Loss: 0.4530\n",
|
||
"Epoch : 2/10, Loss: 0.4597\n",
|
||
"Epoch : 3/10, Loss: 0.4546\n",
|
||
"Epoch : 3/10, Loss: 0.4642\n",
|
||
"Epoch : 3/10, Loss: 0.4499\n",
|
||
"Epoch : 3/10, Loss: 0.4639\n",
|
||
"Epoch : 3/10, Loss: 0.4541\n",
|
||
"Epoch : 3/10, Loss: 0.4662\n",
|
||
"Epoch : 3/10, Loss: 0.4602\n",
|
||
"Epoch : 3/10, Loss: 0.4574\n",
|
||
"Epoch : 3/10, Loss: 0.4523\n",
|
||
"Epoch : 3/10, Loss: 0.4701\n",
|
||
"Epoch : 3/10, Loss: 0.4536\n",
|
||
"Epoch : 3/10, Loss: 0.4593\n",
|
||
"Epoch : 3/10, Loss: 0.4531\n",
|
||
"Epoch : 3/10, Loss: 0.4496\n",
|
||
"Epoch : 3/10, Loss: 0.4651\n",
|
||
"Epoch : 3/10, Loss: 0.4557\n",
|
||
"Epoch : 4/10, Loss: 0.4614\n",
|
||
"Epoch : 4/10, Loss: 0.4517\n",
|
||
"Epoch : 4/10, Loss: 0.4541\n",
|
||
"Epoch : 4/10, Loss: 0.4529\n",
|
||
"Epoch : 4/10, Loss: 0.4641\n",
|
||
"Epoch : 4/10, Loss: 0.4590\n",
|
||
"Epoch : 4/10, Loss: 0.4578\n",
|
||
"Epoch : 4/10, Loss: 0.4534\n",
|
||
"Epoch : 4/10, Loss: 0.4645\n",
|
||
"Epoch : 4/10, Loss: 0.4429\n",
|
||
"Epoch : 4/10, Loss: 0.4533\n",
|
||
"Epoch : 4/10, Loss: 0.4579\n",
|
||
"Epoch : 4/10, Loss: 0.4551\n",
|
||
"Epoch : 4/10, Loss: 0.4468\n",
|
||
"Epoch : 4/10, Loss: 0.4586\n",
|
||
"Epoch : 4/10, Loss: 0.4530\n",
|
||
"Epoch : 5/10, Loss: 0.4383\n",
|
||
"Epoch : 5/10, Loss: 0.4542\n",
|
||
"Epoch : 5/10, Loss: 0.4515\n",
|
||
"Epoch : 5/10, Loss: 0.4523\n",
|
||
"Epoch : 5/10, Loss: 0.4564\n",
|
||
"Epoch : 5/10, Loss: 0.4517\n",
|
||
"Epoch : 5/10, Loss: 0.4580\n",
|
||
"Epoch : 5/10, Loss: 0.4533\n",
|
||
"Epoch : 5/10, Loss: 0.4570\n",
|
||
"Epoch : 5/10, Loss: 0.4625\n",
|
||
"Epoch : 5/10, Loss: 0.4532\n",
|
||
"Epoch : 5/10, Loss: 0.4619\n",
|
||
"Epoch : 5/10, Loss: 0.4534\n",
|
||
"Epoch : 5/10, Loss: 0.4462\n",
|
||
"Epoch : 5/10, Loss: 0.4515\n",
|
||
"Epoch : 5/10, Loss: 0.4533\n",
|
||
"Epoch : 6/10, Loss: 0.4517\n",
|
||
"Epoch : 6/10, Loss: 0.4444\n",
|
||
"Epoch : 6/10, Loss: 0.4564\n",
|
||
"Epoch : 6/10, Loss: 0.4503\n",
|
||
"Epoch : 6/10, Loss: 0.4554\n",
|
||
"Epoch : 6/10, Loss: 0.4498\n",
|
||
"Epoch : 6/10, Loss: 0.4512\n",
|
||
"Epoch : 6/10, Loss: 0.4413\n",
|
||
"Epoch : 6/10, Loss: 0.4497\n",
|
||
"Epoch : 6/10, Loss: 0.4587\n",
|
||
"Epoch : 6/10, Loss: 0.4476\n",
|
||
"Epoch : 6/10, Loss: 0.4568\n",
|
||
"Epoch : 6/10, Loss: 0.4568\n",
|
||
"Epoch : 6/10, Loss: 0.4550\n",
|
||
"Epoch : 6/10, Loss: 0.4527\n",
|
||
"Epoch : 6/10, Loss: 0.4585\n",
|
||
"Epoch : 7/10, Loss: 0.4436\n",
|
||
"Epoch : 7/10, Loss: 0.4496\n",
|
||
"Epoch : 7/10, Loss: 0.4517\n",
|
||
"Epoch : 7/10, Loss: 0.4510\n",
|
||
"Epoch : 7/10, Loss: 0.4520\n",
|
||
"Epoch : 7/10, Loss: 0.4563\n",
|
||
"Epoch : 7/10, Loss: 0.4373\n",
|
||
"Epoch : 7/10, Loss: 0.4375\n",
|
||
"Epoch : 7/10, Loss: 0.4619\n",
|
||
"Epoch : 7/10, Loss: 0.4540\n",
|
||
"Epoch : 7/10, Loss: 0.4569\n",
|
||
"Epoch : 7/10, Loss: 0.4635\n",
|
||
"Epoch : 7/10, Loss: 0.4607\n",
|
||
"Epoch : 7/10, Loss: 0.4435\n",
|
||
"Epoch : 7/10, Loss: 0.4495\n",
|
||
"Epoch : 7/10, Loss: 0.4521\n",
|
||
"Epoch : 8/10, Loss: 0.4473\n",
|
||
"Epoch : 8/10, Loss: 0.4610\n",
|
||
"Epoch : 8/10, Loss: 0.4420\n",
|
||
"Epoch : 8/10, Loss: 0.4545\n",
|
||
"Epoch : 8/10, Loss: 0.4434\n",
|
||
"Epoch : 8/10, Loss: 0.4628\n",
|
||
"Epoch : 8/10, Loss: 0.4430\n",
|
||
"Epoch : 8/10, Loss: 0.4542\n",
|
||
"Epoch : 8/10, Loss: 0.4460\n",
|
||
"Epoch : 8/10, Loss: 0.4588\n",
|
||
"Epoch : 8/10, Loss: 0.4511\n",
|
||
"Epoch : 8/10, Loss: 0.4352\n",
|
||
"Epoch : 8/10, Loss: 0.4478\n",
|
||
"Epoch : 8/10, Loss: 0.4534\n",
|
||
"Epoch : 8/10, Loss: 0.4577\n",
|
||
"Epoch : 8/10, Loss: 0.4470\n",
|
||
"Epoch : 9/10, Loss: 0.4498\n",
|
||
"Epoch : 9/10, Loss: 0.4527\n",
|
||
"Epoch : 9/10, Loss: 0.4533\n",
|
||
"Epoch : 9/10, Loss: 0.4461\n",
|
||
"Epoch : 9/10, Loss: 0.4549\n",
|
||
"Epoch : 9/10, Loss: 0.4489\n",
|
||
"Epoch : 9/10, Loss: 0.4396\n",
|
||
"Epoch : 9/10, Loss: 0.4437\n",
|
||
"Epoch : 9/10, Loss: 0.4424\n",
|
||
"Epoch : 9/10, Loss: 0.4459\n",
|
||
"Epoch : 9/10, Loss: 0.4521\n",
|
||
"Epoch : 9/10, Loss: 0.4540\n",
|
||
"Epoch : 9/10, Loss: 0.4429\n",
|
||
"Epoch : 9/10, Loss: 0.4544\n",
|
||
"Epoch : 9/10, Loss: 0.4495\n",
|
||
"Epoch : 9/10, Loss: 0.4626\n",
|
||
"Epoch : 10/10, Loss: 0.4517\n",
|
||
"Epoch : 10/10, Loss: 0.4376\n",
|
||
"Epoch : 10/10, Loss: 0.4476\n",
|
||
"Epoch : 10/10, Loss: 0.4466\n",
|
||
"Epoch : 10/10, Loss: 0.4462\n",
|
||
"Epoch : 10/10, Loss: 0.4474\n",
|
||
"Epoch : 10/10, Loss: 0.4452\n",
|
||
"Epoch : 10/10, Loss: 0.4580\n",
|
||
"Epoch : 10/10, Loss: 0.4395\n",
|
||
"Epoch : 10/10, Loss: 0.4396\n",
|
||
"Epoch : 10/10, Loss: 0.4473\n",
|
||
"Epoch : 10/10, Loss: 0.4443\n",
|
||
"Epoch : 10/10, Loss: 0.4608\n",
|
||
"Epoch : 10/10, Loss: 0.4610\n",
|
||
"Epoch : 10/10, Loss: 0.4512\n",
|
||
"Epoch : 10/10, Loss: 0.4504\n",
|
||
"Wall time: 23min 1s\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
"%%time\n",
"model.train()\n",
"# Train for 10 epochs\n",
"TOTAL_EPOCHS = 10\n",
"# Record the loss of every batch\n",
"losses = []\n",
"for epoch in range(TOTAL_EPOCHS):\n",
"    for i, (x, y) in enumerate(train_dl):\n",
"        x = x.float().to(DEVICE)  # inputs must be of type float\n",
"        y = y.long().to(DEVICE)  # labels must be of type long\n",
"        # Reset the gradients accumulated in the previous step\n",
"        optimizer.zero_grad()\n",
"        outputs = model(x)\n",
"        # Compute the loss\n",
"        loss = criterion(outputs, y)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        losses.append(loss.item())\n",
"        print('Epoch : %d/%d, Loss: %.4f' % (epoch + 1, TOTAL_EPOCHS, loss.item()))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
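"After training, we can check the model's accuracy. Note that the cell below iterates over `train_dl`, so the figure it reports is accuracy on the training data, not on held-out data."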
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
"Accuracy: 86.0000 %\n"
]
}
],
"source": [
"model.eval()\n",
"correct = 0\n",
"total = 0\n",
"# Gradients are not needed during evaluation\n",
"with torch.no_grad():\n",
"    for x, y in train_dl:\n",
"        x = x.float().to(DEVICE)\n",
"        y = y.long()\n",
"        outputs = model(x).cpu()\n",
"        _, predicted = torch.max(outputs, 1)\n",
"        total += y.size(0)\n",
"        # .item() returns a Python int, so the percentage is not truncated by integer tensor division\n",
"        correct += (predicted == y).sum().item()\n",
"print('Accuracy: %.4f %%' % (100 * correct / total))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
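"With this basic training procedure the accuracy reaches 86%, but the loss stops decreasing at roughly 0.45, which suggests that this network, as configured, has reached its limit. So how can we improve the accuracy further? Later sections introduce more advanced data-processing techniques that can help."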
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "deep learning",
|
||
"language": "python",
|
||
"name": "dl"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.6.9"
|
||
},
|
||
"toc": {
|
||
"base_numbering": 1,
|
||
"nav_menu": {},
|
||
"number_sections": true,
|
||
"sideBar": true,
|
||
"skip_h1_title": true,
|
||
"title_cell": "Table of Contents",
|
||
"title_sidebar": "Contents",
|
||
"toc_cell": false,
|
||
"toc_position": {
|
||
"height": "calc(100% - 180px)",
|
||
"left": "10px",
|
||
"top": "150px",
|
||
"width": "307.2px"
|
||
},
|
||
"toc_section_display": true,
|
||
"toc_window_display": true
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
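
One possible contributor to the loss plateau in the training log above: the sanity-check output shows `grad_fn=<SigmoidBackward>`, so the model's outputs lie between 0 and 1, while `nn.CrossEntropyLoss` applies its own log-softmax to whatever it receives. With two classes and inputs confined to that range, the per-sample loss cannot fall below log(1 + e^-1) ≈ 0.3133. The sketch below checks that bound numerically. It does not use the notebook's model; the tensors are made up for illustration.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
target = torch.tensor([1])

# Limiting best case when both outputs are squashed by a sigmoid:
# the correct class approaches 1 and the wrong class approaches 0.
best_bounded = torch.tensor([[0.0, 1.0]])
floor = criterion(best_bounded, target).item()

# Raw logits are unbounded, so the loss can get arbitrarily close to zero.
raw_logits = torch.tensor([[-10.0, 10.0]])
unbounded = criterion(raw_logits, target).item()

print(f"loss floor with outputs in [0, 1]: {floor:.4f}")
print(f"loss with raw logits:              {unbounded:.2e}")
```

If this is indeed the limiting factor, returning the raw output of the last linear layer from `forward` would be the usual thing to try; whether it improves accuracy on this dataset is not something the notebook verifies.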
|
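
The accuracy check from the notebook can also be wrapped in a small reusable helper, which makes it easy to run on a held-out split as well as on the training data. This is a minimal sketch under assumptions: `evaluate_accuracy` is a hypothetical name, and the toy linear model and random tensors stand in for the notebook's `tabularModel` and `train_dl`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, y in loader:
            x = x.float().to(device)
            y = y.long().to(device)
            predicted = model(x).argmax(dim=1)
            # .item() converts the count to a Python int
            correct += (predicted == y).sum().item()
            total += y.size(0)
    return correct / total


# Stand-in data with 14 features and 2 classes (not the adult dataset).
torch.manual_seed(0)
features = torch.rand(100, 14)
labels = torch.randint(0, 2, (100,))
loader = DataLoader(TensorDataset(features, labels), batch_size=32)

toy_model = nn.Linear(14, 2)  # hypothetical stand-in for tabularModel
acc = evaluate_accuracy(toy_model, loader, torch.device("cpu"))
print(f"accuracy: {100 * acc:.2f} %")
```

With the notebook's objects, the call would look like `evaluate_accuracy(model, train_dl, DEVICE)`; passing a loader built from a separate validation split would give an estimate that is not biased by the training data.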