GANで0を描く

import os

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import tensorboard
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import cv2


class Generator(nn.Module):
    """3-layer MLP mapping a latent vector to a flattened image.

    The final sigmoid keeps outputs in (0, 1), matching the range of
    ToTensor-normalized MNIST pixels fed to the discriminator.
    """

    def __init__(self, latent_size, image_size):
        super().__init__()
        hidden_size = 256
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, image_size)

    def forward(self, x):
        # BUG FIX: F.dropout defaults to training=True regardless of the
        # module's mode, so dropout stayed active even after .eval() was
        # called for sampling. Pass self.training so eval() disables it.
        x = F.dropout(F.leaky_relu(self.fc1(x)), training=self.training)
        x = F.dropout(F.leaky_relu(self.fc2(x)), training=self.training)
        return torch.sigmoid(self.fc3(x))


class Discriminator(nn.Module):
    """3-layer MLP that scores a flattened image as real vs. fake.

    Returns raw logits (no sigmoid) — intended to be paired with
    BCEWithLogitsLoss.
    """

    def __init__(self, image_size):
        super().__init__()
        hidden_size = 48
        self.fc1 = nn.Linear(image_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # BUG FIX: F.dropout defaults to training=True regardless of the
        # module's mode; pass self.training so .eval() disables dropout.
        x = F.dropout(F.leaky_relu(self.fc1(x)), training=self.training)
        x = F.dropout(F.leaky_relu(self.fc2(x)), training=self.training)
        return self.fc3(x)


def train_epoch(data_loader, discriminator, generator, d_optimizer,
                g_optimizer, criterion, latent_size, writer, epoch,
                device="cuda"):
    """Run one GAN training epoch on the digit-0 images of the loader.

    For each batch: update the discriminator on real zeros vs. generated
    fakes, then update the generator to fool the discriminator. Mean
    per-batch losses are logged to `writer` under the given `epoch`.

    Args:
        data_loader: yields (data, target) MNIST batches.
        discriminator, generator: the two adversarial networks.
        d_optimizer, g_optimizer: their respective optimizers.
        criterion: BCEWithLogitsLoss (discriminator outputs logits).
        latent_size: dimensionality of the generator's input noise.
        writer: tensorboard SummaryWriter.
        epoch: current epoch index, used as the logging step.
        device: where tensors are placed (default "cuda", matching the
            original hard-coded behavior; pass "cpu" for CPU runs).
    """
    discriminator.train()
    generator.train()

    d_loss_sum = 0.0
    g_loss_sum = 0.0
    cnt = 0

    for data, target in data_loader:
        # Keep only the real images of digit 0.
        real_images = data[target == 0]
        batch_size = real_images.shape[0]
        if batch_size == 0:
            # A batch may contain no zeros; BCE on an empty tensor
            # would produce NaN, so skip it.
            continue
        real_images = real_images.reshape(batch_size, -1).to(device)

        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Discriminator loss on real data.
        d_real_output = discriminator(real_images)
        d_real_loss = criterion(d_real_output, real_labels)

        # Discriminator loss on fake data. detach() stops this backward
        # pass from wastefully computing generator gradients (they were
        # discarded by g_optimizer.zero_grad() anyway).
        z_for_discriminator = torch.randn(batch_size, latent_size, device=device)
        d_fake_output = discriminator(generator(z_for_discriminator).detach())
        d_fake_loss = criterion(d_fake_output, fake_labels)

        # Optimize discriminator.
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Generator loss: make the discriminator label fakes as real.
        z_for_generator = torch.randn(batch_size, latent_size, device=device)
        g_fake_output = discriminator(generator(z_for_generator))
        g_loss = criterion(g_fake_output, real_labels)

        # Optimize generator.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # BUG FIX: accumulate Python floats via .item(), not loss
        # tensors — summing tensors kept autograd graph references
        # alive across iterations, growing memory over the epoch.
        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()
        cnt += 1

    if cnt > 0:  # guard against an epoch with no usable batches
        writer.add_scalar("discriminator loss", d_loss_sum / cnt, epoch)
        writer.add_scalar("generator loss", g_loss_sum / cnt, epoch)


def main():
    """Train a GAN on MNIST zeros and save a sample image every 5 epochs.

    Requires a CUDA device. Samples are written to img/NNNN.png and
    losses are logged to logs/ for tensorboard.
    """
    torch.manual_seed(0)

    width = 28
    height = 28
    channel_num = 1
    image_size = width * height * channel_num
    latent_size = 64

    transform = transforms.Compose([transforms.ToTensor()])
    mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    data_loader = DataLoader(dataset=mnist, batch_size=1000, shuffle=True)

    discriminator = Discriminator(image_size).cuda()
    generator = Generator(latent_size, image_size).cuda()

    # BCEWithLogitsLoss pairs with the discriminator's raw logit output.
    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = torch.optim.Adam(discriminator.parameters())
    g_optimizer = torch.optim.Adam(generator.parameters())

    writer = tensorboard.SummaryWriter(log_dir="logs")

    # BUG FIX: cv2.imwrite fails silently if the target directory does
    # not exist, so every sample image was lost when img/ was missing.
    os.makedirs('img', exist_ok=True)

    for epoch in range(100):
        print(epoch)
        train_epoch(data_loader, discriminator, generator, d_optimizer,
                    g_optimizer, criterion, latent_size, writer, epoch)
        if epoch % 5 == 0:
            # Sample one image; eval() disables dropout and no_grad()
            # avoids building an unused autograd graph.
            generator.eval()
            with torch.no_grad():
                z = torch.randn(1, latent_size).cuda()
                fake_image = generator(z)
            img = fake_image[0].reshape(height, width, channel_num)
            img = (img.cpu().numpy() * 255).astype('uint8')
            cv2.imwrite('img/{0:04d}.png'.format(epoch), img)

    writer.close()  # flush pending tensorboard events


# Script entry point: train only when executed directly, not on import.
if __name__ == '__main__':
    main()

f:id:LeMU_Research:20201114021805p:plain      f:id:LeMU_Research:20201114021839p:plain

f:id:LeMU_Research:20201114021956p:plain