> 文档中心 > Pytorch 使用DCGAN生成动漫人物头像 入门级实战教程

Pytorch 使用DCGAN生成动漫人物头像 入门级实战教程

有关DCGAN实战的小例子之前已经更新过一篇,感兴趣的朋友可以点击查看

Pytorch 使用DCGAN生成MNIST手写数字 入门级教程

有关DCGAN的相关原理:DCGAN论文解读-----DCGAN原理简介与基础GAN的区别

一、数据集说明

本实验使用的动漫人物头像数据集共有两万多张动漫人物头像图片，已上传为资源供大家免费下载。注意：下文的网络结构假定每张图片都是 64×64 的 RGB 图像（判别器全连接层的输入尺寸 128*15*15 正是由 64×64 的输入经两次步长为 2 的卷积计算得到的）。

动漫人物头像数据集 anime-face：https://download.csdn.net/download/m0_62128864/85072972

二、读取数据集

# Read images from disk
class Face_dataset(data.Dataset):
    """Map-style dataset that loads one image per file path and returns it as a tensor.

    Relies on a module-level ``transform`` callable (defined elsewhere in the
    file) to turn each PIL image into a tensor.
    """

    def __init__(self, imgs_path):
        # Sequence of image file paths, e.g. the result of glob.glob(...).
        self.imgs_path = imgs_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        # convert('RGB') forces exactly 3 channels. Palette/RGBA/grayscale PNGs
        # would otherwise yield 1 or 4 channels and break a 3-channel model.
        # The `with` block closes the file handle that PIL opens lazily.
        with Image.open(img_path) as pil_img:
            return transform(pil_img.convert('RGB'))

    def __len__(self):
        return len(self.imgs_path)

三、定义生成器

# Generator definition
class Generator(nn.Module):
    """DCGAN-style generator: a 100-dim input vector -> a 3x64x64 image in [-1, 1].

    Spatial sizes: linear1 is reshaped to 256x16x16; deconv1 (k3, s1, p1)
    keeps 16x16; deconv2 and deconv3 (k4, s2, p1) each double it to 32x32,
    then 64x64.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order, so seeded initialization
        # and state_dict keys are unchanged.
        self.linear1 = nn.Linear(100, 256 * 16 * 16)
        self.bn1 = nn.BatchNorm1d(256 * 16 * 16)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=1, padding=1)  # -> 128x16x16
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1)  # -> 64x32x32
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=2, padding=1)  # -> 3x64x64

    def forward(self, x):
        # Project the input vector and reshape it into a 256-channel 16x16 map.
        h = self.bn1(F.relu(self.linear1(x))).view(-1, 256, 16, 16)
        # Two transposed-conv stages, each applying ReLU and then BatchNorm.
        for deconv, norm in ((self.deconv1, self.bn2), (self.deconv2, self.bn3)):
            h = norm(F.relu(deconv(h)))
        # tanh bounds every output pixel to [-1, 1].
        return torch.tanh(self.deconv3(h))

四、定义判别器（鉴别器）

# Discriminator definition; input: (batch, 3, 64, 64) images
class Discriminator(nn.Module):
    """DCGAN-style discriminator: maps a batch of 3x64x64 images to probabilities in (0, 1).

    Shape trace (kernel 3, stride 2, no padding): 64 -> conv1 -> 31 -> conv2 -> 15,
    which is why ``fc`` takes 128*15*15 features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 15 * 15, 1)

    def forward(self, x):
        # Dropout deliberately weakens the discriminator. Pass training=self.training:
        # the functional form defaults to training=True and would otherwise
        # keep dropping channels even after model.eval().
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), training=self.training)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), training=self.training)  # (batch, 128, 15, 15)
        x = self.bn(x)
        # flatten(1) keeps the batch dimension explicit. view(-1, 128*15*15)
        # would silently fold a wrongly sized input into a different batch size.
        x = x.flatten(1)  # (batch, 128*15*15)
        return torch.sigmoid(self.fc(x))

五、完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image

imgs_path = glob.glob('data/anime-faces/*.png')

# ToTensor gives a [0, 1] CHW tensor. Normalize maps each RGB channel to
# [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


# Read images from disk
class Face_dataset(data.Dataset):
    """Map-style dataset that loads one image per path and applies ``transform``."""

    def __init__(self, imgs_path):
        self.imgs_path = imgs_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        # convert('RGB') guarantees 3 channels (palette/RGBA/grayscale PNGs would
        # otherwise break Conv2d(3, ...)). `with` closes PIL's lazily opened file.
        with Image.open(img_path) as pil_img:
            return transform(pil_img.convert('RGB'))

    def __len__(self):
        return len(self.imgs_path)


dataset = Face_dataset(imgs_path)
BATCH_SIZE = 32
# Fail early and clearly. With no images, DataLoader raises a cryptic sampler
# error; with fewer than BATCH_SIZE images, drop_last would leave zero batches
# and the per-epoch loss average would divide by zero.
if len(dataset) < BATCH_SIZE:
    raise ValueError(
        f"found {len(dataset)} PNG images in 'data/anime-faces/', need at least {BATCH_SIZE}"
    )
# drop_last=True: a trailing batch holding a single image would make the
# generator's BatchNorm1d raise in training mode.
dataloader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


# Generator: 100-dim noise -> 3x64x64 image in [-1, 1]
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(100, 256 * 16 * 16)
        self.bn1 = nn.BatchNorm1d(256 * 16 * 16)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=1, padding=1)  # -> 128x16x16
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1)  # -> 64x32x32
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=2, padding=1)  # -> 3x64x64

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.bn1(x)
        x = x.view(-1, 256, 16, 16)
        x = F.relu(self.deconv1(x))
        x = self.bn2(x)
        x = F.relu(self.deconv2(x))
        x = self.bn3(x)
        x = torch.tanh(self.deconv3(x))
        return x


# Discriminator; input: (batch, 3, 64, 64)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2)    # 64 -> 31
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)  # 31 -> 15
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 15 * 15, 1)

    def forward(self, x):
        # Dropout weakens the discriminator. training=self.training is needed
        # because the functional form defaults to training=True, even in eval().
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), training=self.training)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), training=self.training)  # (batch, 128, 15, 15)
        x = self.bn(x)
        x = x.flatten(1)  # (batch, 128*15*15); a wrongly sized input fails loudly in fc
        x = torch.sigmoid(self.fc(x))
        return x


# Initialize the models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

# Loss function
loss_function = torch.nn.BCELoss()

# Optimizers (the discriminator deliberately learns 10x slower than the generator)
d_optim = torch.optim.Adam(dis.parameters(), lr=0.00001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)


def generate_and_save_images(model, epoch, test_input):
    """Display the images `model` produces from the fixed noise `test_input`.

    Despite the name, nothing is written to disk; the grid (4 columns,
    as many rows as needed) is shown with plt.show(). `epoch` is used only
    for the figure title.
    """
    # no_grad: otherwise the output requires grad and .numpy() raises
    # whenever this helper is called outside a torch.no_grad() block.
    with torch.no_grad():
        predictions = model(test_input).permute(0, 2, 3, 1).cpu().numpy()
    ncols = 4
    nrows = max(1, -(-predictions.shape[0] // ncols))  # ceil division; 8 samples -> 2x4
    plt.figure(figsize=(12, 8))
    plt.suptitle(f'Epoch {epoch}')
    for i in range(predictions.shape[0]):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow((predictions[i] + 1) / 2)  # tanh range [-1, 1] -> [0, 1]
        plt.axis("off")
    plt.show()


# Fixed noise so the preview grids are comparable across epochs.
test_input = torch.randn(8, 100, device=device)

# Training
D_loss = []
G_loss = []

for epoch in range(100):
    d_epoch_loss = 0.0
    g_epoch_loss = 0.0
    batch_count = len(dataloader)  # number of batches per epoch
    for step, img in enumerate(dataloader):
        img = img.to(device)
        size = img.shape[0]
        random_noise = torch.randn(size, 100, device=device)

        # Discriminator step: real images -> 1, generated images -> 0.
        d_optim.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_function(real_output, torch.ones_like(real_output))
        d_real_loss.backward()
        gen_img = gen(random_noise)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output))
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()

        # Generator step: push the discriminator to label generated images as real.
        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_function(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()

        # .item() already returns a detached Python float.
        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()

    d_epoch_loss /= batch_count  # mean loss per batch
    g_epoch_loss /= batch_count
    D_loss.append(d_epoch_loss)
    G_loss.append(g_epoch_loss)
    generate_and_save_images(gen, epoch, test_input)
    print('Epoch:', epoch)

plt.plot(D_loss, label='D_loss')
plt.plot(G_loss, label='G_loss')
plt.legend()
plt.show()

六、运行结果展示

Epoch:0

Epoch:50 

Epoch:80

央视天气网