Update the RNN parameter notes and add a hand-written RNN implementation

This commit is contained in:
zergtant 2020-04-22 17:12:58 +08:00
parent e5810ef041
commit e9c995d79a
2 changed files with 105 additions and 11 deletions

BIN
chapter2/13.png Normal file


Width:  |  Height:  |  Size: 28 KiB


@ -2,16 +2,16 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.0.0'"
"'1.4.0'"
]
},
"execution_count": 1,
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
@ -72,7 +72,7 @@
"Recurrent neural networks have a particularly strong memory property and can apply remembered content to the current context, but the network's memory is not as effective as one might imagine. Memory's biggest problem is that it is forgetful: we always remember recent events more clearly than things that happened long ago, and recurrent neural networks suffer from the same problem.\n",
"\n",
"In PyTorch, the nn.RNN class builds a sequence-based recurrent neural network. Its constructor takes the following parameters:\n",
"- nput_size: the number of features of the input data X. \n",
"- input_size: the number of features of the input data X. \n",
"- hidden_size: the number of neurons in the hidden layer, i.e. the number of hidden-layer features.\n",
"- num_layers: the number of recurrent layers; the default is 1. \n",
"- bias: defaults to True; if False, the neurons use no bias term.\n",
@ -84,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 38,
"metadata": {},
"outputs": [
{
@ -103,6 +103,87 @@
"print(output.size(),hn.size())"
]
},
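As a quick sanity check of the constructor parameters listed above, here is a minimal sketch of instantiating `nn.RNN` with them spelled out by name (the tensor shapes and variable names are our own choices):

```python
import torch
from torch import nn

# one-layer RNN: 10 input features, 20 hidden units, tanh activation (the default)
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1, bias=True, nonlinearity='tanh')

# input is (seq_len, batch, input_size); h_0 is (num_layers, batch, hidden_size)
x = torch.randn(5, 3, 10)
h0 = torch.randn(1, 3, 20)
output, hn = rnn(x, h0)

print(output.shape)  # torch.Size([5, 3, 20]) -- one hidden vector per time step
print(hn.shape)      # torch.Size([1, 3, 20]) -- final hidden state per layer
```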
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beginners reading the description above are probably still confused: what is all of this, and how is it used in practice?\n",
"Below we hand-write an RNN implementation in PyTorch; building it ourselves gives a much deeper understanding of the RNN's structure.\n",
"\n",
"Before implementing it, let's look more closely at how an RNN works. An RNN is really just an ordinary neural network that carries an extra hidden_state to preserve history. The hidden_state's job is to store the previous state; the 'memory' an RNN is often said to keep is exactly this hidden_state.\n",
"\n",
"For an RNN, there is only one formula to remember:\n",
"\n",
"$h_t = \\tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh}) $\n",
"\n",
"This formula comes from the official documentation:\n",
"https://pytorch.org/docs/stable/nn.html?highlight=rnn#torch.nn.RNN\n",
"\n",
"Here $x_t$ is the input at the current step, and $h_{(t-1)}$ is the previous step's hidden_state mentioned above, i.e. the memory part.\n",
"The trainable parts of the network are $W_{ih}$, the weight on the current input, $W_{hh}$, the weight on the hidden_state (the previous state), and the two bias terms. These four terms are summed and passed through tanh; PyTorch uses tanh as the default activation, and relu can be selected instead via a constructor option.\n",
"\n",
"\n",
"The steps just described make up the single computation outlined by the red box:\n",
"![](13.png)\n",
"\n",
"This step is no different from an ordinary neural network; what an RNN adds is the sequence dimension, so the same model runs n forward passes, where n is the sequence length we set.\n",
"Below we implement our RNN by hand, following karpathy's article: https://karpathy.github.io/2015/05/21/rnn-effectiveness/"
]
},
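The formula above can be checked directly against `nn.RNN`'s own parameters (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`); a minimal sketch, with shapes chosen arbitrarily:

```python
import torch
from torch import nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=6)

x_t = torch.randn(1, 4)       # one time step, batch of 1
h_prev = torch.zeros(1, 6)    # previous hidden state

# manual step: h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
h_manual = torch.tanh(x_t @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                      + h_prev @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)

# nn.RNN expects (seq_len, batch, input_size), so add the sequence dimension
_, h_rnn = rnn(x_t.unsqueeze(0), h_prev.unsqueeze(0))

print(torch.allclose(h_manual, h_rnn.squeeze(0), atol=1e-6))  # True
```

The two results agree, confirming that a single RNN step is exactly this one formula.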
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
"class RNN(object):\n",
"    def __init__(self, input_size, hidden_size):\n",
"        super().__init__()\n",
"        self.W_xh = torch.nn.Linear(input_size, hidden_size)  # both branches are summed, so each must map to hidden_size\n",
"        self.W_hh = torch.nn.Linear(hidden_size, hidden_size)\n",
"\n",
"    def __call__(self, x, hidden):\n",
"        return self.step(x, hidden)\n",
"\n",
"    def step(self, x, hidden):\n",
"        # one step of the forward pass\n",
"        h1 = self.W_hh(hidden)\n",
"        w1 = self.W_xh(x)\n",
"        out = torch.tanh(h1 + w1)\n",
"        hidden = out  # the new hidden state is the activated output, passed on to the next step\n",
"        return out, hidden"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
"rnn = RNN(20, 50)\n",
"input = torch.randn(32, 20)\n",
"h_0 = torch.randn(32, 50)\n",
"seq_len = input.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 50]) torch.Size([32, 50])\n"
]
}
],
"source": [
"hidden = h_0\n",
"for i in range(seq_len):\n",
"    output, hidden = rnn(input[i, :], hidden)  # carry the hidden state forward through time\n",
"print(output.size(), hidden.size())"
]
},
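PyTorch also ships `nn.RNNCell`, a built-in version of exactly this single-step module; a sketch of unrolling it over a sequence while carrying the hidden state forward (the shapes mirror the hand-written cells above, and the sequence length of 7 is an arbitrary choice):

```python
import torch
from torch import nn

# one step of an RNN: 20 input features -> 50 hidden units
cell = nn.RNNCell(input_size=20, hidden_size=50)

seq_len, batch = 7, 32
inputs = torch.randn(seq_len, batch, 20)
h = torch.zeros(batch, 50)

# unroll: feed each time step and pass the hidden state to the next step
for t in range(seq_len):
    h = cell(inputs[t], h)

print(h.shape)  # torch.Size([32, 50])
```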
{
"cell_type": "markdown",
"metadata": {},
@ -127,7 +208,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 61,
"metadata": {},
"outputs": [
{
@ -160,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 62,
"metadata": {},
"outputs": [
{
@ -176,7 +257,7 @@
"input = torch.randn(5, 3, 10)\n",
"h0 = torch.randn(2, 3, 20)\n",
"output, hn = rnn(input, h0)\n",
"print(output.size(),h0.size())"
"print(output.size(),hn.size())"
]
},
{
@ -278,9 +359,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "pytorch 1.0",
"display_name": "deeplearning",
"language": "python",
"name": "pytorch1"
"name": "dl"
},
"language_info": {
"codemirror_mode": {
@ -292,7 +373,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
"version": "3.7.7"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,