【Pytorch】使用Pytorch进行知识蒸馏
1. 导入各种包
import pandas as pdimport numpy as npimport matplotlib.pyplot as pltimport tensorflow as tfimport torchimport torch.nn.functional as Fimport torchvisionfrom torch import nnfrom torchvision import transformsfrom torch.utils.data import DataLoader# from torchinfo import summaryfrom tqdm import tqdm
2. 设置随机种子
#设置随机种子torch.manual_seed(0)# device = torch.device("cuda" if torch.cuda.is_available() else "pcu") # 使用云GPU# 使用cuDNN加速卷积运算torch.backends.cudnn.benchmark=True
3. 加载 MNIST 数据集
执行后,MNIST数据集会下载到"dataset/"
文件夹下
# 载入训练集train_dataset = torchvision.datasets.MNIST( root="dataset/", # MNIST数据集存放目录 train=True, #为train=True 时,加载训练集 transform=transforms.ToTensor(), # 图像处理、转不同格式显示 download=True)# 载入测试集test_dataset = torchvision.datasets.MNIST( root="dataset/", train=False, #为train=False 时,加载测试集 transform=transforms.ToTensor(), # 图像处理、转不同格式显示 download=True)train_loder = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)test_loder = DataLoader(dataset=test_dataset, batch_size=32,shuffle=False) # 从数据库中每次抽出batch size个样本
4. 定义教师模型
class TeacherModel(nn.Module): def __init__(self,in_channels=1,num_classes=10): super(TeacherModel, self).__init__() self.relu = nn.ReLU() self.fc1 = nn.Linear(784,1200) self.fc2 = nn.Linear(1200,1200) self.fc2 = nn.Linear(1200,num_classes) self.dropout = nn.Dropout(p = 0.5) def forward(self,x): x = x.view(-1,784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) x = self.dropout(x) x = self.relu(x) x = self.fc3(x) return x
附录
1. 关于 import torch.nn as nn
torch.nn
是用于设置网络中的全连接层的,需要注意在二维图像处理的任务中,全连接层的输入与输出一般都设置为二维张量,形状通常为[batch_size, size]
,不同于卷积层要求输入输出是四维张量。
in_features
指的是输入的二维张量的大小,即输入的[batch_size, size]中的size。
out_features
指的是输出的二维张量的大小,即输出的二维张量的形状为[batch_size,output_size]
,当然,它也代表了该全连接层的神经元个数。从输入输出的张量的shape角度来理解,相当于一个输入为[batch_size, in_features]
的张量变换成了[batch_size, out_features]
的输出张量。
import torch as tfrom torch import nn# in_features由输入张量的形状决定,out_features则决定了输出张量的形状 connected_layer = nn.Linear(in_features = 64*64*3, out_features = 1)# 假定输入的图像形状为[64,64,3]input = t.randn(1,64,64,3)# 将四维张量转换为二维张量之后,才能作为全连接层的输入input = input.view(1,64*64*3)print(input.shape)output = connected_layer(input) # 调用全连接层print(output.shape)# 运行结果:# input shape is %s torch.Size([1, 12288])# output shape is %s torch.Size([1, 1])
2. 关于 nn.functional
import torch.nn.functional as F
包含 torch.nn 库中所有函数
同时包含大量 loss 和 activation function
import torch.nn.functional as Floss_func = F.cross_entropyloss = loss_func(model(x), y)loss.backward()
其中 loss.backward() 更新模型的梯度,包括 weights 和 bias
3. 关于from torch.utils.data import DataLoader
DataLoader:数据加载器
,结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=<function default_collate>,pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None)
4. 关于torch.optim
torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。
为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态,并能够根据计算得到的梯度来更新参数。
要构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数(所有的参数必须是变量s)的列表。 然后,您可以指定程序优化特定的选项,例如学习速率,权重衰减等。
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)optimizer = optim.Adam([var1, var2], lr = 0.0001)self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
4. 关于model.train()
model.train()
的作用是启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。
4. 关于optimizer.zero_grad()
optimizer.zero_grad()意思是把梯度置零,也就是把loss关于weight的导数变成0.
另外Pytorch 为什么每一轮batch需要设置optimizer.zero_grad:
根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了。
在学习pytorch的时候注意到,对于每个batch大都执行了这样的操作:
# zero the parameter gradients optimizer.zero_grad() # 梯度初始化为零 # forward + backward + optimize outputs = net(inputs) # 前向传播求出预测的值 loss = criterion(outputs, labels) # 求loss loss.backward() # 反向传播求梯度 optimizer.step() # 更新所有参数