修正:一个严重的错误
This commit is contained in:
parent
6fff3a7323
commit
8b8455694f
@ -302,9 +302,12 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.405838\n",
|
||||
"Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.206041\n",
|
||||
"Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.144166\n"
|
||||
"Train Epoch: 0 [14848/60000 (25%)]\tLoss: 0.271775\n",
|
||||
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
|
||||
"Train Epoch: 0 [30208/60000 (50%)]\tLoss: 0.175213\n",
|
||||
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n",
|
||||
"Train Epoch: 0 [45568/60000 (75%)]\tLoss: 0.115128\n",
|
||||
"warning: Embedding dir exists, did you set global_step for add_embedding()?\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -382,7 +385,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vgg16 = models.vgg16() # 这里下载预训练好的模型\n",
|
||||
"vgg16 = models.vgg16(pretrained=True) # 这里下载预训练好的模型\n",
|
||||
"print(vgg16) # 打印一下这个模型"
|
||||
]
|
||||
},
|
||||
@ -401,14 +404,10 @@
|
||||
"source": [
|
||||
"transform_2 = transforms.Compose([\n",
|
||||
" transforms.Resize(224), \n",
|
||||
" transforms.CenterCrop(224),\n",
|
||||
" transforms.CenterCrop((224,224)),\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" # convert RGB to BGR\n",
|
||||
" # from <https://github.com/mrzhu-cool/pix2pix-pytorch/blob/master/util.py>\n",
|
||||
" transforms.Lambda(lambda x: torch.index_select(x, 0, torch.LongTensor([2, 1, 0]))),\n",
|
||||
" transforms.Lambda(lambda x: x*255),\n",
|
||||
" transforms.Normalize(mean = [103.939, 116.779, 123.68],\n",
|
||||
" std = [ 1, 1, 1 ]),\n",
|
||||
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
|
||||
" std=[0.229, 0.224, 0.225])\n",
|
||||
"])"
|
||||
]
|
||||
},
|
||||
@ -453,17 +452,21 @@
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(1, 1000) 931\n"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"287"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"raw_score = vgg16(vgg16_input)\n",
|
||||
"raw_score_numpy = raw_score.data.numpy()\n",
|
||||
"print(raw_score_numpy.shape, np.argmax(raw_score_numpy.ravel()))"
|
||||
"out = vgg16(vgg16_input)\n",
|
||||
"_, preds = torch.max(out.data, 1)\n",
|
||||
"label=preds.numpy()[0]\n",
|
||||
"label"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -487,7 +490,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"打开tensorboard找到graphs 看看效果吧"
|
||||
"打开tensorboard找到graphs就可以看到vgg模型具体的架构了"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user