import os

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import tensorboard
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import cv2
class Generator(nn.Module):
    """MLP generator: maps a latent vector to a flattened image in [0, 1].

    Architecture: latent_size -> 256 -> 256 -> image_size, with leaky ReLU
    and dropout on the hidden layers and a sigmoid on the output.
    """

    def __init__(self, latent_size, image_size):
        super().__init__()
        hidden_size = 256
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, image_size)

    def forward(self, x):
        # F.dropout defaults to training=True regardless of module mode;
        # pass self.training explicitly so .eval() actually disables dropout.
        x = F.dropout(F.leaky_relu(self.fc1(x)), training=self.training)
        x = F.dropout(F.leaky_relu(self.fc2(x)), training=self.training)
        # Sigmoid keeps pixel values in [0, 1], matching ToTensor()-scaled MNIST.
        return torch.sigmoid(self.fc3(x))
class Discriminator(nn.Module):
    """MLP discriminator: maps a flattened image to a single real/fake logit.

    Architecture: image_size -> 48 -> 48 -> 1, with leaky ReLU and dropout on
    the hidden layers. The output is a raw logit (no sigmoid) — intended to be
    paired with BCEWithLogitsLoss.
    """

    def __init__(self, image_size):
        super().__init__()
        hidden_size = 48
        self.fc1 = nn.Linear(image_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # F.dropout defaults to training=True regardless of module mode;
        # pass self.training explicitly so .eval() actually disables dropout.
        x = F.dropout(F.leaky_relu(self.fc1(x)), training=self.training)
        x = F.dropout(F.leaky_relu(self.fc2(x)), training=self.training)
        return self.fc3(x)
def train_epoch(data_loader, discriminator, generator, d_optimizer,
                g_optimizer, criterion, latent_size, writer, epoch):
    """Run one adversarial training epoch and log the mean D/G losses.

    Only images whose target label is 0 are used as "real" samples.
    The device is inferred from the discriminator's parameters, so the
    function works unchanged on CPU or CUDA models.

    Args:
        data_loader: iterable of (images, labels) batches.
        discriminator, generator: the two adversarial networks.
        d_optimizer, g_optimizer: their respective optimizers.
        criterion: logit-based loss (e.g. BCEWithLogitsLoss).
        latent_size: dimensionality of the generator's noise input.
        writer: TensorBoard-style writer with an add_scalar method.
        epoch: step index used for logging.
    """
    discriminator.train()
    generator.train()
    device = next(discriminator.parameters()).device
    d_loss_sum = 0.0
    g_loss_sum = 0.0
    cnt = 0
    for data, target in data_loader:
        # Extract real images (class 0 only).
        real_images = data[target == 0]
        batch_size = real_images.shape[0]
        if batch_size == 0:
            # No class-0 images in this batch; skip to avoid degenerate steps.
            continue
        real_images = real_images.reshape(batch_size, -1).to(device)
        # Labels: 1 = real, 0 = fake.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)
        # Discriminator loss on real data.
        d_real_output = discriminator(real_images)
        d_real_loss = criterion(d_real_output, real_labels)
        # Discriminator loss on fake data. detach() stops gradients from
        # flowing into the generator during the discriminator update.
        z_for_discriminator = torch.randn(batch_size, latent_size, device=device)
        d_fake_output = discriminator(generator(z_for_discriminator).detach())
        d_fake_loss = criterion(d_fake_output, fake_labels)
        # Optimize discriminator.
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        # Generator loss: make the discriminator label fakes as real.
        z_for_generator = torch.randn(batch_size, latent_size, device=device)
        g_fake_output = discriminator(generator(z_for_generator))
        g_loss = criterion(g_fake_output, real_labels)
        # Optimize generator.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        # .item() detaches from the autograd graph; accumulating the raw loss
        # tensors would keep every batch's graph alive for the whole epoch.
        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()
        cnt += 1
    if cnt > 0:  # guard against an empty loader / no class-0 samples
        writer.add_scalar("discriminator loss", d_loss_sum / cnt, epoch)
        writer.add_scalar("generator loss", g_loss_sum / cnt, epoch)
def main():
    """Train a GAN on MNIST class-0 digits and periodically save samples.

    Side effects: downloads MNIST into ./data, writes TensorBoard logs to
    ./logs, and saves a sample image to ./img every 5 epochs. Requires CUDA.
    """
    torch.manual_seed(0)
    width = 28
    height = 28
    channel_num = 1
    image_size = width * height * channel_num
    latent_size = 64
    transform = transforms.Compose([transforms.ToTensor()])
    mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    data_loader = DataLoader(dataset=mnist, batch_size=1000, shuffle=True)
    discriminator = Discriminator(image_size).cuda()
    generator = Generator(latent_size, image_size).cuda()
    # BCEWithLogitsLoss expects raw logits, matching Discriminator's output.
    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = torch.optim.Adam(discriminator.parameters())
    g_optimizer = torch.optim.Adam(generator.parameters())
    writer = tensorboard.SummaryWriter(log_dir="logs")
    # cv2.imwrite fails silently when the target directory does not exist.
    os.makedirs('img', exist_ok=True)
    for epoch in range(100):
        print(epoch)
        train_epoch(data_loader, discriminator, generator, d_optimizer,
                    g_optimizer, criterion, latent_size, writer, epoch)
        if epoch % 5 == 0:
            # Sample one image; eval mode and no_grad since this is inference.
            # (train_epoch switches the model back to train mode next epoch.)
            generator.eval()
            with torch.no_grad():
                z = torch.randn(1, latent_size).cuda()
                fake_image = generator(z)
            img = fake_image[0].reshape(height, width, channel_num)
            # Scale [0, 1] sigmoid output to 8-bit grayscale for cv2.
            img = (img.cpu().numpy() * 255).astype('uint8')
            cv2.imwrite('img/{0:04d}.png'.format(epoch), img)


if __name__ == '__main__':
    main()