Drawing Faces with a GAN
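
This post trains a DCGAN-style generator and discriminator on the CelebA face dataset to produce 64x64 face images. The script below is the whole thing: a Dataset that loads the img_align_celeba images, the two convolutional networks, and a training loop that logs both losses to TensorBoard, writes one generated sample per epoch, and checkpoints the models every ten epochs.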

import os
import glob
import cv2
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils import tensorboard
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class CelebDataset(Dataset):
    def __init__(self, data_dir, img_width):
        self.img_width = img_width
        self.filelist = glob.glob(os.path.join(data_dir, '*'))

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        # load as BGR uint8, resize to a square, scale to [0, 1]
        img = cv2.imread(self.filelist[idx])
        img = cv2.resize(img, dsize=(self.img_width, self.img_width)).astype('float32') / 255.0
        # HWC -> CHW, the channel order PyTorch expects
        img = img.transpose(2, 0, 1)
        return img
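
The dataset returns each image as a float32 array of shape (3, img_width, img_width) scaled to [0, 1], in OpenCV's BGR channel order (which is fine here, since the samples are also written back out with cv2.imwrite); the default DataLoader collate converts these arrays to tensors. A quick sanity check, assuming the img_align_celeba folder sits in the working directory:

dataset = CelebDataset('img_align_celeba', img_width=64)
sample = dataset[0]
print(sample.shape, sample.dtype)  # (3, 64, 64) float32, values in [0, 1]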


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.last_ch = 128
        # four stride-2 4x4 convolutions, each halving the spatial size
        self.conv1 = nn.Conv2d(3, self.last_ch // 8, 4, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.last_ch // 8)
        self.conv2 = nn.Conv2d(self.last_ch // 8, self.last_ch // 4, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.last_ch // 4)
        self.conv3 = nn.Conv2d(self.last_ch // 4, self.last_ch // 2, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.last_ch // 2)
        self.conv4 = nn.Conv2d(self.last_ch // 2, self.last_ch, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(self.last_ch)
        # a valid 4x4 convolution collapses the 4x4 feature map to one value
        self.conv5 = nn.Conv2d(self.last_ch, 1, 4, 1, 0, bias=False)

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = F.leaky_relu(self.bn4(self.conv4(x)))
        # raw logit per image; BCEWithLogitsLoss applies the sigmoid itself
        return self.conv5(x).view(-1, 1)
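
Each stride-2 layer maps an input of width W to (W + 2*1 - 4)//2 + 1 = W//2, so a 64x64 input shrinks 64 → 32 → 16 → 8 → 4, and conv5 collapses the remaining 4x4 map to 1x1. A dummy forward pass confirms the shapes (a minimal sketch; CPU is fine for this):

d = Discriminator()
out = d(torch.zeros(2, 3, 64, 64))
print(out.shape)  # torch.Size([2, 1]): one real/fake logit per image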


class Generator(nn.Module):
    def __init__(self, latent_size):
        super().__init__()
        self.first_ch = 512
        # expand the (latent_size, 1, 1) noise vector to a 4x4 feature map
        self.deconv0 = nn.ConvTranspose2d(latent_size, self.first_ch, 4, 1, 0, bias=False)
        self.bn0 = nn.BatchNorm2d(self.first_ch)
        # four stride-2 transposed convolutions, each doubling the spatial size
        self.deconv1 = nn.ConvTranspose2d(self.first_ch, self.first_ch // 2, 4, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.first_ch // 2)
        self.deconv2 = nn.ConvTranspose2d(self.first_ch // 2, self.first_ch // 4, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.first_ch // 4)
        self.deconv3 = nn.ConvTranspose2d(self.first_ch // 4, self.first_ch // 8, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.first_ch // 8)
        self.deconv4 = nn.ConvTranspose2d(self.first_ch // 8, 3, 4, 2, 1, bias=False)

    def forward(self, x):
        x = F.leaky_relu(self.bn0(self.deconv0(x)))
        x = F.leaky_relu(self.bn1(self.deconv1(x)))
        x = F.leaky_relu(self.bn2(self.deconv2(x)))
        x = F.leaky_relu(self.bn3(self.deconv3(x)))
        # sigmoid keeps outputs in [0, 1], matching the training images
        return torch.sigmoid(self.deconv4(x))
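
The generator mirrors the discriminator: the width grows 1 → 4 → 8 → 16 → 32 → 64 while the channels fall 512 → 256 → 128 → 64 → 3. Again a quick shape check (a minimal sketch):

g = Generator(latent_size=100)
fake = g(torch.randn(2, 100, 1, 1))
print(fake.shape)  # torch.Size([2, 3, 64, 64])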


def train_epoch(data_loader, discriminator, generator, d_optimizer, g_optimizer,
                criterion, latent_size, writer, epoch):

    discriminator.train()
    generator.train()

    d_loss_sum = 0.0
    g_loss_sum = 0.0
    cnt = 0

    for real_images in tqdm(data_loader):

        real_images = real_images.cuda()

        real_labels = torch.ones(real_images.shape[0], 1).cuda()
        fake_labels = torch.zeros(real_images.shape[0], 1).cuda()

        # loss of discriminator for real data
        d_real_output = discriminator(real_images)
        d_real_loss = criterion(d_real_output, real_labels)

        # loss of discriminator for fake data; detach() stops gradients from
        # this loss flowing back into the generator
        z_for_discriminator = torch.randn(real_images.shape[0], latent_size, 1, 1).cuda()
        d_fake_output = discriminator(generator(z_for_discriminator).detach())
        d_fake_loss = criterion(d_fake_output, fake_labels)

        # optimize discriminator
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # loss of generator: non-saturating loss, i.e. label the fakes as
        # real so the generator is rewarded for fooling the discriminator
        z_for_generator = torch.randn(real_images.shape[0], latent_size, 1, 1).cuda()
        g_fake_output = discriminator(generator(z_for_generator))
        g_loss = criterion(g_fake_output, real_labels)

        # optimize generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # accumulate plain floats; summing the loss tensors directly would
        # keep an autograd graph alive across the whole epoch
        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()

        cnt += 1

    writer.add_scalar("discriminator loss", d_loss_sum / cnt, epoch)
    writer.add_scalar("generator loss", g_loss_sum / cnt, epoch)


def main():

    torch.manual_seed(0)

    img_width = 64
    latent_size = 100
    batch_size = 512

    # output directories for per-epoch samples and checkpoints
    os.makedirs('img', exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)

    dataset = CelebDataset('img_align_celeba', img_width=img_width)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    discriminator = Discriminator().cuda()
    generator = Generator(latent_size).cuda()

    discriminator = nn.DataParallel(discriminator)
    generator = nn.DataParallel(generator)

    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    writer = tensorboard.SummaryWriter(log_dir="logs")

    for epoch in range(300):
        print(epoch)
        train_epoch(data_loader, discriminator, generator, d_optimizer, g_optimizer,
                    criterion, latent_size, writer, epoch)

        # save one generated sample per epoch to see training progress
        generator.eval()
        with torch.no_grad():
            z = torch.randn(1, latent_size, 1, 1).cuda()
            fake_image = generator(z)
        img = (fake_image[0].cpu().numpy().transpose(1, 2, 0) * 255).astype('uint8')
        cv2.imwrite('img/{0:04d}.png'.format(epoch), img)

        if epoch % 10 == 0:
            torch.save(discriminator.state_dict(), 'checkpoints/discriminator{0:03d}.pth'.format(epoch))
            torch.save(generator.state_dict(), 'checkpoints/generator{0:03d}.pth'.format(epoch))


if __name__ == '__main__':
    main()
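
To sample from a saved checkpoint later, note that state_dict() was taken from the nn.DataParallel wrappers, so every key carries a "module." prefix. A minimal sketch for single-GPU inference, assuming the Generator class above is importable and using the epoch-290 checkpoint as an example filename:

import cv2
import torch

latent_size = 100
generator = Generator(latent_size).cuda()

# strip the "module." prefix that nn.DataParallel added to each key
state = torch.load('checkpoints/generator290.pth')
state = {k.replace('module.', '', 1): v for k, v in state.items()}
generator.load_state_dict(state)
generator.eval()

with torch.no_grad():
    z = torch.randn(16, latent_size, 1, 1).cuda()
    fakes = generator(z)

for i, fake in enumerate(fakes):
    img = (fake.cpu().numpy().transpose(1, 2, 0) * 255).astype('uint8')
    cv2.imwrite('sample{0:02d}.png'.format(i), img)

Alternatively, wrap the fresh Generator in nn.DataParallel before calling load_state_dict, and the prefixed keys match as-is.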
