# PyTorch Lightning MNIST example

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from sklearn.model_selection import train_test_split


class MNISTModel(LightningModule):
    """Minimal two-layer MLP classifier for MNIST digits.

    Architecture: flatten -> Linear(784, 1024) -> ReLU -> Linear(1024, 10).
    Trained with cross-entropy over the 10 digit classes.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 1024)
        self.l2 = torch.nn.Linear(1024, 10)
        # NOTE(review): recent torchmetrics versions require
        # Accuracy(task="multiclass", num_classes=10) — confirm the pinned version.
        self.accuracy = Accuracy()

    def forward(self, x):
        """Return raw class logits of shape (batch, 10).

        The input is flattened to (batch, 784), so any shape whose trailing
        dimensions multiply to 784 is accepted (e.g. (batch, 1, 28, 28)).
        """
        x = torch.relu(self.l1(x.view(x.size(0), -1)))
        return self.l2(x)

    def training_step(self, batch, batch_idx=0):
        """Compute and log epoch-averaged training loss/accuracy for one batch.

        ``batch_idx`` added (with a default, so older call styles still work)
        for consistency with ``validation_step`` and Lightning's hook signature.
        """
        loss, accuracy = self.step(batch)
        # Logging 'step' as the epoch index makes the logger plot the
        # epoch-aggregated metrics against epochs instead of global steps.
        self.log_dict({'train_loss': loss,
                       'train_accuracy': accuracy,
                       'step': torch.tensor(self.current_epoch, dtype=torch.float32)},
                      on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log epoch-averaged validation loss/accuracy for one batch."""
        loss, accuracy = self.step(batch)
        self.log_dict({'val_loss': loss,
                       'val_accuracy': accuracy,
                       'step': torch.tensor(self.current_epoch, dtype=torch.float32)},
                      on_step=False, on_epoch=True)
        return loss

    def step(self, batch):
        """Shared train/val logic: return (cross-entropy loss, accuracy)."""
        x, y = batch
        # Call through __call__ (not .forward) so module/Lightning hooks run.
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy = self.accuracy(logits, y)
        return loss, accuracy

    def configure_optimizers(self):
        """Adam over all parameters.

        NOTE(review): lr=0.02 is unusually high for Adam; kept as-is to
        preserve behavior — tune if training is unstable.
        """
        return torch.optim.Adam(self.parameters(), lr=0.02)


class MNISTDataModule(LightningDataModule):
    """DataModule serving a 90/10 train/val split of MNIST plus the test set.

    Datasets are built eagerly in ``__init__`` (downloading into ./data on
    first use), so the loaders work even when ``setup`` is never invoked by a
    Trainer. The previously commented-out ``setup`` duplicate was removed as
    dead code.
    """

    def __init__(self, batch_size, num_workers):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        train_val = MNIST('data', train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST('data', train=False, download=True, transform=transforms.ToTensor())
        # Fixed random_state keeps the 90/10 split reproducible across runs.
        self.mnist_train, self.mnist_val = train_test_split(train_val, test_size=0.1, random_state=0)

    def train_dataloader(self):
        """Shuffled loader over the 90% training split."""
        return DataLoader(self.mnist_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        """Deterministic loader over the 10% validation split."""
        return DataLoader(self.mnist_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Deterministic loader over the official MNIST test set."""
        return DataLoader(self.mnist_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)


def train():
    """Fit the MNIST model for up to 50 epochs, checkpointing the epoch
    with the lowest validation loss."""
    dm = MNISTDataModule(batch_size=32, num_workers=0)
    model = MNISTModel()

    checkpoint_cb = ModelCheckpoint(monitor='val_loss', filename='{epoch:02d}', mode='min')
    # NOTE(review): gpus=[1] pins the *second* GPU and the ``gpus`` argument
    # is removed in Lightning >= 2.0 (use accelerator/devices) — confirm the
    # installed version before changing.
    trainer = Trainer(gpus=[1], max_epochs=50, callbacks=[checkpoint_cb])
    trainer.fit(model, dm)


def test():
    """Evaluate a saved checkpoint on the MNIST test set and print accuracy.

    The original loop computed the model output and discarded it; the loop
    now accumulates the number of correct top-1 predictions and reports the
    overall accuracy.
    """
    data_module = MNISTDataModule(batch_size=1, num_workers=0)
    data_module.setup(stage='test')
    data_loader = data_module.test_dataloader()

    # NOTE(review): hard-coded checkpoint path from a specific run — update
    # (or parameterize) to point at the checkpoint you want to evaluate.
    model = MNISTModel.load_from_checkpoint('lightning_logs/version_17/checkpoints/epoch=06.ckpt')
    model.freeze()

    correct = 0
    total = 0
    for data, target in data_loader:
        out = model(data)
        correct += (out.argmax(dim=1) == target).sum().item()
        total += target.size(0)
    print(f'test accuracy: {correct / total:.4f}')


# Entry point: train by default; switch to test() to evaluate a checkpoint.
if __name__ == '__main__':
    train()
    # test()