基于Pytorch的神经网络之autoencoder
目录
1. 引言
2.结构
3.搭建网络
4.代码
1. 引言
今天我们来学习一种在压缩数据方面有较好效果的神经网络:自编码(autoencoder)。
2.结构
自编码的主要思想就是将数据先不断encode进行降维提取其中的关键信息,再decode解码成新的信息,我们的目标是要使我们生成的信息和原信息尽可能相似,有点像我们做题一样,先刷大量的题,提取其中的关键解题思路,遇到同类型的题目就会写了,基本结构如下
这个网络先将信息不断压缩,再解压,对比原始数据和新数据的差别,再反向传播修正参数,最后输出的新数据就会越来越接近原始数据,最后最中间的一层就是这组信息的关键特征。
3.搭建网络
我们以输出手写数字为例,我们还是只看网络的搭建过程:
#搭建网络class AutoEncoder(nn.Module): def __init__(self): super(AutoEncoder, self).__init__() self.encoder = nn.Sequential( nn.Linear(28*28, 128), nn.Tanh(), nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 12), nn.Tanh(), nn.Linear(12, 3), ) self.decoder = nn.Sequential( nn.Linear(3, 12), nn.Tanh(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 128), nn.Tanh(), nn.Linear(128, 28*28), nn.Sigmoid(),) def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return encoded, decoded
我们需要构建编码和解码两个网络,编码部分就是将数据一层一层压缩,解码部分就是对应着一层层解压,中间加入非线性函数,最后再加一个激活函数,但是要确保数据的范围与原来一样,因为我们的目标是生成与原数据一样的数据,最后前向传播先编码再解码即可。
下面是效果
开始:
中间:
最后:
可以看出输出图片与原始图片越来越接近了。
最后的损失:
loss:0.33
4.代码
#调库import torchimport torch.nn as nnimport torch.utils.data as Dataimport torchvisionimport matplotlib.pyplot as pltimport numpy as np#超参数EPOCH = 10BATCH_SIZE = 64LR = 0.005 # learning rateDOWNLOAD_MNIST = FalseN_TEST_IMG = 5#下载数据train_data = torchvision.datasets.MNIST( root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST, )#小批数据train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)#搭建网络class AutoEncoder(nn.Module): def __init__(self): super(AutoEncoder, self).__init__() self.encoder = nn.Sequential( nn.Linear(28*28, 128), nn.Tanh(), nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 12), nn.Tanh(), nn.Linear(12, 3), ) self.decoder = nn.Sequential( nn.Linear(3, 12), nn.Tanh(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 128), nn.Tanh(), nn.Linear(128, 28*28), nn.Sigmoid(),) def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return encoded, decodedautoencoder = AutoEncoder()optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)loss_func = nn.MSELoss()f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))plt.ion() view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.for i in range(N_TEST_IMG): a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())#训练并绘图for epoch in range(EPOCH): for step, (x, b_label) in enumerate(train_loader): b_x = x.view(-1, 28*28) b_y = x.view(-1, 28*28) encoded, decoded = autoencoder(b_x) loss = loss_func(decoded, b_y) optimizer.zero_grad() loss.backward() optimizer.step() if step % 100 == 0: print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy()) _, decoded_data = autoencoder(view_data) for i in range(N_TEST_IMG): a[1][i].clear() a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray') a[1][i].set_xticks(()); a[1][i].set_yticks(()) plt.draw(); plt.pause(0.05)plt.ioff()plt.show()