> 文档中心 > 现有网络模型的使用及修改

现有网络模型的使用及修改

PyTorch 官网 → Docs → torchvision → torchvision.models（Models and pre-trained weights）
查看特定网络并添加线性层

import torchvision
from torch import nn

# NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=` (e.g. `weights=None` / `weights="DEFAULT"`); `pretrained=` is
# kept here so the snippet also runs on older torchvision versions.

# VGG16 without pretrained weights.
vgg16_False = torchvision.models.vgg16(pretrained=False)
# VGG16 with ImageNet-pretrained weights.
vgg16_True = torchvision.models.vgg16(pretrained=True)

# VGG16 was designed for ImageNet classification: the printed model shows
# that the last classifier layer has out_features=1000, i.e. 1000 classes.
# How can the network be adapted to the (10-class) CIFAR10 dataset?
print(vgg16_True)

# CIFAR10 training set (downloaded into ./data if missing); it is only
# loaded here for illustration and is not used further in this snippet.
train_data = torchvision.datasets.CIFAR10(
    "data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Append an extra linear layer to the `classifier` Sequential so the
# 1000 ImageNet outputs are mapped to 10 CIFAR10 classes.
vgg16_True.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_True)

原始VGG16网络结构

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

修改后网络结构

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_linear): Linear(in_features=1000, out_features=10, bias=True)
  )
)

查看特定网络并修改其中的特定层

print(vgg16_False)

# Modify a specific layer of VGG16: classifier[6] is the final
# Linear(in_features=4096, out_features=1000) layer (see the printed model),
# so replace it with a 4096 -> 10 layer for the 10 CIFAR10 classes.
# Relies on `vgg16_False` and `nn` from the previous snippet.
vgg16_False.classifier[6] = nn.Linear(4096, 10)
print(vgg16_False)

修改后结果

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)

墨言文学成语大全