MNIST classifier (PyTorch Lightning 1.9.4)

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchvision import transforms
from torchvision.datasets import MNIST
from torchmetrics import Accuracy


class LitMNIST(LightningModule):

    def __init__(self):
        super().__init__()

        input_size = 28 * 28
        hidden_size = 256
        num_classes = 10
        self.batch_size = 256
        self.learning_rate = 2e-4
        self.num_workers = 0
        self.data_dir = 'data'

        # Normalize with the MNIST training-set mean and std
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Three-layer MLP: 784 -> 256 -> 256 -> 10 logits
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        preds = torch.argmax(output, dim=1)
        self.val_accuracy.update(preds, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        # Logging the Metric object itself lets Lightning compute and reset it at epoch end
        self.log("val_acc", self.val_accuracy, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        preds = torch.argmax(output, dim=1)
        self.test_accuracy.update(preds, y)

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", self.test_accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)
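    # Note: downloading inside setup() is fine on a single device; with multiple
    # devices or nodes, the download usually belongs in prepare_data(), which
    # Lightning calls on a single process before setup() runs.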

    def train_dataloader(self):
        # Shuffle so each epoch visits the training examples in a different order
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)


def main():
    model = LitMNIST()
    # Stop when val_loss has not improved for 10 consecutive epochs
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=10, mode="min")
    trainer = Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=100,
        logger=CSVLogger(save_dir="logs/"),
        callbacks=[early_stop_callback],
    )
    trainer.fit(model)
    trainer.test(model)


if __name__ == '__main__':
    main()
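
After a run, the logged metrics can be inspected from the CSV file the logger writes. A minimal sketch, assuming the default run directory logs/lightning_logs/version_0 (CSVLogger bumps the version_N suffix on each new run) and that pandas is installed:

import pandas as pd

# One column per logged metric; training and validation metrics are written on
# separate rows, so drop the NaNs before reading a given metric.
metrics = pd.read_csv("logs/lightning_logs/version_0/metrics.csv")
print(metrics[["epoch", "val_loss", "val_acc"]].dropna().tail())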