Drawing Large Faces with a GAN

The script below trains a GAN on the CelebA face images (img_align_celeba) at 128x128 resolution. The discriminator is a spectral-normalized ResNet trained with the hinge loss, the generator is a DCGAN-style stack of transposed convolutions with batch normalization; sample images are written to img/ every epoch and checkpoints to checkpoints/ every 10 epochs.

import os
import glob
import cv2
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm
from torch.utils import tensorboard
from torch.utils.data import Dataset, DataLoader
import torchvision
from tqdm import tqdm


class CelebDataset(Dataset):
    def __init__(self, data_dir, img_width, transform=None):
        self.img_width = img_width
        self.transform = transform
        self.filelist = glob.glob(os.path.join(data_dir, '*'))

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        # Read with OpenCV (BGR), resize, and convert to RGB.
        img = cv2.imread(self.filelist[idx])
        img = cv2.resize(img, dsize=(self.img_width, self.img_width))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            # ToTensor scales to [0, 1]; Normalize(0.5, 0.5) then maps to [-1, 1],
            # matching the tanh output range of the generator.
            return self.transform(img)
        img = img.astype('float32') / 127.5 - 1.0
        return img.transpose(2, 0, 1)


class ResBlock(nn.Module):
    # Residual block for the discriminator: two spectral-normalized 3x3 convolutions
    # plus a spectral-normalized 1x1 convolution on the shortcut.

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        conv1_ = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        conv2_ = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        nn.init.xavier_uniform_(conv1_.weight.data, 1.)
        nn.init.xavier_uniform_(conv2_.weight.data, 1.)
        self.conv1 = spectral_norm(conv1_)
        self.conv2 = spectral_norm(conv2_)

        # 1x1 convolution so the shortcut matches the main path in channels and resolution.
        downsample_ = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False)
        nn.init.xavier_uniform_(downsample_.weight.data, 1.)
        self.downsample = spectral_norm(downsample_)

    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), negative_slope=0.2, inplace=True)
        out = self.conv2(out)

        identity = self.downsample(x)

        out += identity
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.last_ch = 512
        self.block1 = ResBlock(3, self.last_ch // 16, 2)
        self.block2 = ResBlock(self.last_ch // 16, self.last_ch // 8, 2)
        self.block3 = ResBlock(self.last_ch // 8, self.last_ch // 4, 2)
        self.block4 = ResBlock(self.last_ch // 4, self.last_ch // 2, 2)
        self.block5 = ResBlock(self.last_ch // 2, self.last_ch, 2)
        self.last_conv = nn.Conv2d(self.last_ch, 1, 4, 1, 0, bias=False)

    def forward(self, x):
        x = F.leaky_relu(self.block1(x), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.block2(x), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.block3(x), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.block4(x), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.block5(x), negative_slope=0.2, inplace=True)
        return self.last_conv(x).view(-1, 1)


class Generator(nn.Module):
    def __init__(self, latent_size):
        super().__init__()
        self.first_ch = 512
        # (latent_size, 1, 1) -> (512, 4, 4), then the resolution doubles at every layer up to (3, 128, 128).
        self.deconv0 = nn.ConvTranspose2d(latent_size, self.first_ch, 4, 1, 0, bias=False)
        self.bn0 = nn.BatchNorm2d(self.first_ch)
        self.deconv1 = nn.ConvTranspose2d(self.first_ch, self.first_ch // 2, 4, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.first_ch // 2)
        self.deconv2 = nn.ConvTranspose2d(self.first_ch // 2, self.first_ch // 4, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.first_ch // 4)
        self.deconv3 = nn.ConvTranspose2d(self.first_ch // 4, self.first_ch // 8, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.first_ch // 8)
        self.deconv4 = nn.ConvTranspose2d(self.first_ch // 8, self.first_ch // 16, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(self.first_ch // 16)
        self.deconv5 = nn.ConvTranspose2d(self.first_ch // 16, 3, 4, 2, 1, bias=False)

        nn.init.xavier_uniform_(self.deconv0.weight.data, 1.)
        nn.init.xavier_uniform_(self.deconv1.weight.data, 1.)
        nn.init.xavier_uniform_(self.deconv2.weight.data, 1.)
        nn.init.xavier_uniform_(self.deconv3.weight.data, 1.)
        nn.init.xavier_uniform_(self.deconv4.weight.data, 1.)
        nn.init.xavier_uniform_(self.deconv5.weight.data, 1.)

    def forward(self, x):
        x = F.leaky_relu(self.bn0(self.deconv0(x)), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn1(self.deconv1(x)), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn2(self.deconv2(x)), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn3(self.deconv3(x)), negative_slope=0.2, inplace=True)
        x = F.leaky_relu(self.bn4(self.deconv4(x)), negative_slope=0.2, inplace=True)
        # tanh keeps the output in [-1, 1], the same range as the normalized real images.
        return torch.tanh(self.deconv5(x))


def train_epoch(data_loader, discriminator, generator, d_optimizer, g_optimizer,
                criterion, latent_size, writer, epoch):

    discriminator.train()
    generator.train()

    d_loss_sum = 0
    g_loss_sum = 0
    cnt = 0

    for real_images in tqdm(data_loader):

        real_images = real_images.cuda()

        # Labels are only needed for the commented-out BCE variant of the losses.
        real_labels = torch.ones(real_images.shape[0], 1).cuda()
        fake_labels = torch.zeros(real_images.shape[0], 1).cuda()

        # loss of discriminator for real data (hinge loss)
        d_real_output = discriminator(real_images)
        # d_real_loss = criterion(d_real_output, real_labels)
        d_real_loss = F.relu(1.0 - d_real_output).mean()

        # loss of discriminator for fake data
        # (detach so this update does not backpropagate into the generator)
        z_for_discriminator = torch.randn(real_images.shape[0], latent_size, 1, 1).cuda()
        d_fake_output = discriminator(generator(z_for_discriminator).detach())
        # d_fake_loss = criterion(d_fake_output, fake_labels)
        d_fake_loss = F.relu(1.0 + d_fake_output).mean()

        # optimize discriminator
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # loss of generator
        z_for_generator = torch.randn(real_images.shape[0], latent_size, 1, 1).cuda()
        fake_imgs = generator(z_for_generator)
        g_fake_output = discriminator(fake_imgs)
        # g_loss = criterion(g_fake_output, real_labels)
        g_loss = -g_fake_output.mean()

        # optimize generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # accumulate Python floats rather than tensors so the computation graph is released
        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()

        cnt += 1

    writer.add_scalar("discriminator loss", d_loss_sum / cnt, epoch)
    writer.add_scalar("generator loss", g_loss_sum / cnt, epoch)
    # map the tanh output from [-1, 1] back to [0, 1] before saving
    torchvision.utils.save_image(fake_imgs[:64].detach() * 0.5 + 0.5, "img/epoch_{0:03d}.png".format(epoch), nrow=8)


def main():

    torch.manual_seed(0)

    img_width = 128
    latent_size = 128
    batch_size = 512

    # ToTensor + Normalize map the uint8 RGB images to [-1, 1].
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])

    dataset = CelebDataset('img_align_celeba', img_width=img_width, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    discriminator = Discriminator().cuda()
    generator = Generator(latent_size).cuda()

    discriminator = nn.DataParallel(discriminator)
    generator = nn.DataParallel(generator)

    # criterion is only used by the commented-out BCE variant in train_epoch; the hinge loss is used instead.
    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    writer = tensorboard.SummaryWriter(log_dir="logs")
    if not os.path.exists('img'):
        os.mkdir('img')
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')

    for epoch in range(300):
        print(epoch)
        train_epoch(data_loader, discriminator, generator, d_optimizer, g_optimizer,
                    criterion, latent_size, writer, epoch)

        if epoch % 10 == 0:
            torch.save(discriminator.state_dict(), 'checkpoints/discriminator{0:03d}.pth'.format(epoch))
            torch.save(generator.state_dict(), 'checkpoints/generator{0:03d}.pth'.format(epoch))

    writer.close()


if __name__ == '__main__':
    main()
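
As a quick sanity check of the two networks, the snippet below (a standalone sketch, assuming the Generator and Discriminator classes above are in scope) pushes a small batch of random latent vectors through the generator and feeds the result to the discriminator, printing the tensor shapes at each step.

import torch

generator = Generator(latent_size=128)
discriminator = Discriminator()

z = torch.randn(2, 128, 1, 1)   # batch of 2 latent vectors
fake = generator(z)             # torch.Size([2, 3, 128, 128])
score = discriminator(fake)     # torch.Size([2, 1]) - raw (unbounded) scores for the hinge loss
print(fake.shape, score.shape)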

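To sample faces from a trained model, something like the following should work (a minimal sketch: the checkpoint name is just an example, and because the training script saves state_dicts from nn.DataParallel wrappers, the "module." prefix has to be stripped before loading into an unwrapped Generator).

import torch
import torchvision

latent_size = 128
generator = Generator(latent_size).cuda()

state_dict = torch.load('checkpoints/generator290.pth')  # example checkpoint name
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}  # strip the DataParallel prefix
generator.load_state_dict(state_dict)
generator.eval()

with torch.no_grad():
    z = torch.randn(64, latent_size, 1, 1).cuda()
    fake = generator(z)
torchvision.utils.save_image(fake * 0.5 + 0.5, 'samples.png', nrow=8)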