# CycleGAN training script (unpaired image-to-image translation)

import os
import glob
import random
import cv2
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils import tensorboard
import torchvision
from tqdm import tqdm
import itertools


class CycleDataset(Dataset):
    """Unaligned two-domain image dataset for CycleGAN training.

    Pairs the idx-th image of domain A with a uniformly random image of
    domain B, so the (unpaired) pairing changes every epoch. Length is
    max(len(A), len(B)); indexing into the smaller domain wraps around.
    """

    def __init__(self, A_dir, B_dir, transform):
        # Sort for a deterministic ordering; glob order is filesystem-dependent.
        self.A_paths = sorted(glob.glob(os.path.join(A_dir, '*')))
        self.B_paths = sorted(glob.glob(os.path.join(B_dir, '*')))
        if not self.A_paths or not self.B_paths:
            # Fail loudly here instead of with a ZeroDivisionError in __getitem__.
            raise ValueError('both A_dir and B_dir must contain at least one image')
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.transform = transform

    def __len__(self):
        return max(self.A_size, self.B_size)

    def _load(self, path):
        """Read an image as a contiguous float32 RGB array in [0, 1]."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None (no exception) on unreadable files.
            raise IOError('failed to read image: {}'.format(path))
        # cv2 loads BGR; cvtColor returns a contiguous RGB copy, unlike the
        # [:, :, ::-1] view, whose negative stride torch.from_numpy rejects
        # inside torchvision's ToTensor.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype('float32') / 255.0

    def __getitem__(self, idx):
        A_img = self._load(self.A_paths[idx % self.A_size])
        B_img = self._load(self.B_paths[random.randint(0, self.B_size - 1)])
        return self.transform(A_img), self.transform(B_img)


class ResBlock(nn.Module):
def __init__(self, ch):
super().__init__()

self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.InstanceNorm2d(ch)
self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.InstanceNorm2d(ch)

nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
nn.init.xavier_uniform_(self.conv2.weight.data, 1.)

def forward(self, x):
out = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.2, inplace=True)
out = self.bn2(self.conv2(out))
return F.leaky_relu(x + out, negative_slope=0.2, inplace=True)


class Generator(nn.Module):
    """Encoder / residual transformer / decoder generator.

    3-channel input -> 7x7 stem, two stride-2 downsamplings (64->128->256),
    nine 256-channel residual blocks, two stride-2 transposed-conv
    upsamplings, 7x7 head -> 3-channel tanh output in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder.
        self.conv_first = nn.Conv2d(3, 64, kernel_size=7, padding=3)
        self.bn_first = nn.InstanceNorm2d(64)
        self.downsample1 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn_down1 = nn.InstanceNorm2d(128)
        self.downsample2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.bn_down2 = nn.InstanceNorm2d(256)

        # Transformation: nine residual blocks at 256 channels.
        self.res_blocks = nn.Sequential(*[ResBlock(256) for _ in range(9)])

        # Decoder.
        self.upsample1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.bn_up1 = nn.InstanceNorm2d(128)
        self.upsample2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.bn_up2 = nn.InstanceNorm2d(64)
        self.conv_last = nn.Conv2d(64, 3, kernel_size=7, padding=3)

        # Xavier init (gain 1.0) for every conv owned directly by this module;
        # the residual blocks initialize themselves.
        for conv in (self.conv_first, self.downsample1, self.downsample2,
                     self.upsample1, self.upsample2, self.conv_last):
            nn.init.xavier_uniform_(conv.weight.data, 1.)

    def forward(self, x):
        for conv, norm in ((self.conv_first, self.bn_first),
                           (self.downsample1, self.bn_down1),
                           (self.downsample2, self.bn_down2)):
            x = F.leaky_relu(norm(conv(x)), negative_slope=0.2, inplace=True)
        x = self.res_blocks(x)
        for conv, norm in ((self.upsample1, self.bn_up1),
                           (self.upsample2, self.bn_up2)):
            x = F.leaky_relu(norm(conv(x)), negative_slope=0.2, inplace=True)
        # tanh keeps outputs in [-1, 1], matching the Normalize(0.5, 0.5) input range.
        return torch.tanh(self.conv_last(x))


class Discriminator(nn.Module):
def __init__(self):
super().__init__()

self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
self.bn2 = nn.InstanceNorm2d(128)
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.bn3 = nn.InstanceNorm2d(256)
self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.bn4 = nn.InstanceNorm2d(512)
self.conv5 = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)

nn.init.xavier_uniform_(self.conv1.weight.data, 1.)
nn.init.xavier_uniform_(self.conv2.weight.data, 1.)
nn.init.xavier_uniform_(self.conv3.weight.data, 1.)
nn.init.xavier_uniform_(self.conv4.weight.data, 1.)
nn.init.xavier_uniform_(self.conv5.weight.data, 1.)

def forward(self, x):
x = F.leaky_relu(self.conv1(x), negative_slope=0.2, inplace=True)
x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2, inplace=True)
x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2, inplace=True)
x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2, inplace=True)
return torch.sigmoid(self.conv5(x))


def train_epoch(train_loader, G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D,
                criterion_MSE, criterion_L1, writer, epoch):
    """Run one CycleGAN training epoch.

    For each batch: a generator step (adversarial + cycle-consistency +
    identity losses), then a discriminator step on detached fakes. Saves a
    preview grid of the last batch to img/ and logs per-epoch mean losses
    to TensorBoard via `writer`.
    """
    if len(train_loader) == 0:
        return  # nothing to train on; avoids division by zero below

    sums = {}  # loss-name -> running sum over batches

    for A_real_img, B_real_img in tqdm(train_loader):

        A_real_img = A_real_img.cuda()
        B_real_img = B_real_img.cuda()

        # PatchGAN targets: D emits a 1x16x16 score map per 128x128 image.
        real_labels = torch.ones(A_real_img.shape[0], 1, 16, 16).cuda()
        fake_labels = torch.zeros(A_real_img.shape[0], 1, 16, 16).cuda()

        # --- Generator step: freeze D so its gradients aren't computed.
        for param in itertools.chain(D_A.parameters(), D_B.parameters()):
            param.requires_grad = False

        B_fake_img = G_AB(A_real_img)
        A_cycle_img = G_BA(B_fake_img)
        A_fake_img = G_BA(B_real_img)
        B_cycle_img = G_AB(A_fake_img)

        # Identity mapping: each generator should leave its own output domain unchanged.
        A_idt_img = G_BA(A_real_img)
        B_idt_img = G_AB(B_real_img)

        loss_G_AB = criterion_MSE(D_B(B_fake_img), real_labels)
        loss_G_BA = criterion_MSE(D_A(A_fake_img), real_labels)
        loss_cycle_A = criterion_L1(A_cycle_img, A_real_img) * 10.0  # lambda_cycle = 10
        loss_cycle_B = criterion_L1(B_cycle_img, B_real_img) * 10.0
        loss_idt_A = criterion_L1(A_idt_img, A_real_img) * 5.0  # lambda_idt = 5
        loss_idt_B = criterion_L1(B_idt_img, B_real_img) * 5.0

        optimizer_G.zero_grad()
        loss_G = loss_G_AB + loss_G_BA + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()
        optimizer_G.step()

        # --- Discriminator step: unfreeze D.
        for param in itertools.chain(D_A.parameters(), D_B.parameters()):
            param.requires_grad = True

        # Reuse the fakes computed above, detached so no gradient flows into
        # the generators. (Previously each fake was recomputed with a full
        # generator forward pass, wasting compute and accumulating useless
        # generator gradients that the next zero_grad had to clear.)
        pred_A_real = D_A(A_real_img)
        pred_A_fake = D_A(A_fake_img.detach())
        loss_D_A_real = criterion_MSE(pred_A_real, real_labels) * 0.5
        loss_D_A_fake = criterion_MSE(pred_A_fake, fake_labels) * 0.5

        pred_B_real = D_B(B_real_img)
        pred_B_fake = D_B(B_fake_img.detach())
        loss_D_B_real = criterion_MSE(pred_B_real, real_labels) * 0.5
        loss_D_B_fake = criterion_MSE(pred_B_fake, fake_labels) * 0.5

        optimizer_D.zero_grad()
        loss_D = loss_D_A_real + loss_D_A_fake + loss_D_B_real + loss_D_B_fake
        loss_D.backward()
        optimizer_D.step()

        for name, value in (('loss_G_AB', loss_G_AB), ('loss_G_BA', loss_G_BA),
                            ('loss_cycle_A', loss_cycle_A), ('loss_cycle_B', loss_cycle_B),
                            ('loss_idt_A', loss_idt_A), ('loss_idt_B', loss_idt_B),
                            ('loss_D_A_real', loss_D_A_real), ('loss_D_A_fake', loss_D_A_fake),
                            ('loss_D_B_real', loss_D_B_real), ('loss_D_B_fake', loss_D_B_fake)):
            sums[name] = sums.get(name, 0.0) + value.item()

    # Preview grid from the last batch. Row 1: A, G_AB(A), cycle(A), idt(A);
    # row 2: B, G_BA(B), cycle(B), idt(B). Detach and move to CPU before
    # saving (the originals are CUDA tensors still attached to the graph).
    show_imgs = torch.stack([
        A_real_img[0], B_fake_img[0], A_cycle_img[0], A_idt_img[0],
        B_real_img[0], A_fake_img[0], B_cycle_img[0], B_idt_img[0],
    ]).detach().cpu()
    # Images are in [-1, 1]; map back to [0, 1] for saving.
    torchvision.utils.save_image((show_imgs + 1) / 2, "img/{0:03d}.png".format(epoch), nrow=4)

    num_batches = len(train_loader)
    for name, total in sums.items():
        writer.add_scalar(name, total / num_batches, epoch)


def main():
    """Train CycleGAN on monet2photo: data pipeline, models, optimizers, training loop."""

    batch_size = 16
    epoch_num = 50

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # (128, 128) forces a square output. A bare 128 only resizes the
        # shorter side, which would break the hard-coded 16x16 PatchGAN label
        # maps and the 128x128 preview grid for non-square source images.
        torchvision.transforms.Resize((128, 128)),
        # Map [0, 1] pixels to [-1, 1] to match the generators' tanh range.
        # The single-element mean/std broadcasts across all three channels.
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])

    train_dataset = CycleDataset('monet2photo/trainA', 'monet2photo/trainB', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    G_AB = Generator().cuda()   # translates domain A -> B
    G_BA = Generator().cuda()   # translates domain B -> A
    D_A = Discriminator().cuda()  # judges domain-A images
    D_B = Discriminator().cuda()  # judges domain-B images

    G_AB = torch.nn.DataParallel(G_AB)
    G_BA = torch.nn.DataParallel(G_BA)
    D_A = torch.nn.DataParallel(D_A)
    D_B = torch.nn.DataParallel(D_B)

    # One optimizer over both generators and one over both discriminators,
    # with the standard GAN Adam settings (lr=2e-4, beta1=0.5).
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                                   lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(itertools.chain(D_A.parameters(), D_B.parameters()),
                                   lr=0.0002, betas=(0.5, 0.999))

    criterion_MSE = torch.nn.MSELoss()  # least-squares adversarial loss
    criterion_L1 = torch.nn.L1Loss()    # cycle-consistency / identity loss

    writer = tensorboard.SummaryWriter(log_dir="logs")
    # makedirs(exist_ok=True) is race-free, unlike an exists()+mkdir pair.
    os.makedirs('img', exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)

    for epoch in range(epoch_num):
        train_epoch(train_loader, G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D,
                    criterion_MSE, criterion_L1, writer, epoch)
        if epoch % 5 == 0:
            # Save the underlying modules (.module) so checkpoints load
            # without a DataParallel wrapper.
            torch.save(G_AB.module.state_dict(), 'checkpoints/G_AB{0:03d}.pth'.format(epoch))
            torch.save(G_BA.module.state_dict(), 'checkpoints/G_BA{0:03d}.pth'.format(epoch))
            torch.save(D_A.module.state_dict(), 'checkpoints/D_A{0:03d}.pth'.format(epoch))
            torch.save(D_B.module.state_dict(), 'checkpoints/D_B{0:03d}.pth'.format(epoch))


# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()