ESC50音声分類をシンプルなCNNでやってみた

※改良記事書きました。

メルスペクトラムとデータ水増しでESC50の精度を上げる - LeMU_Researchの日記

 

・npy化

"""Convert ESC-50 wav clips into spectral-envelope .npy arrays.

For each clip listed in meta/esc50.csv: extract the WORLD spectral
envelope, resize it to 512x512, scale to [0, 1], and save the
(1, 512, 512) float32 array under data/npy/.
"""
import os

import cv2
import numpy as np
import pandas as pd
import pyworld as pw
from scipy.io import wavfile

data_dir = 'data'
df = pd.read_csv(os.path.join(data_dir, 'meta/esc50.csv'))

# Make sure the output directory exists before the loop writes into it.
os.makedirs(os.path.join(data_dir, 'npy'), exist_ok=True)

for idx in range(len(df)):
    fname = df['filename'][idx]
    sample_rate, signal_int = wavfile.read(os.path.join(data_dir, 'audio', fname))

    # pyworld requires a float64 signal.  The original used np.float,
    # an alias removed in NumPy 1.24 (AttributeError on current NumPy).
    # Normalize by the peak *magnitude* (np.abs) so a clip whose largest
    # excursion is negative is scaled correctly; this is what the
    # "overflow prevention" intent calls for.
    signal_float = signal_int.astype(np.float64)
    signal_float /= np.max(np.abs(signal_float))

    # WORLD analysis pipeline.
    f0, t = pw.dio(signal_float, sample_rate)             # raw F0 estimate
    f0 = pw.stonemask(signal_float, f0, t, sample_rate)   # F0 refinement
    sp = pw.cheaptrick(signal_float, f0, t, sample_rate)  # spectral envelope

    # Resize to a fixed 512x512, scale to [0, 1], and add a leading
    # channel axis so the array is (1, H, W) as the CNN expects.
    sp = cv2.resize(sp, (512, 512), interpolation=cv2.INTER_LINEAR)
    sp /= np.max(sp)
    sp = sp[np.newaxis, ...].astype('float32')

    # np.save appends the .npy suffix; fname[:-4] strips '.wav'.
    np.save(os.path.join(data_dir, 'npy', fname[:-4]), sp)

 

・学習

import os
import pandas as pd
import numpy as np
import torch
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class Down(nn.Module):
    """One encoder stage: halve the spatial size, then conv + BN + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Pool first to cut H and W in half, then widen the channels.
        stage = [
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.maxpool_conv = nn.Sequential(*stage)

    def forward(self, feature_map):
        """Return the downsampled, re-convolved feature map."""
        return self.maxpool_conv(feature_map)


class EscNet(nn.Module):
    """Simple CNN classifier for ESC-50 (50 classes).

    Input is a (N, 1, H, W) spectral envelope.  Five Down stages shrink
    the spatial map 32x while widening channels 32 -> 1024; global
    average pooling collapses the spatial dims and two FC layers
    produce the (N, 50) class logits.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, dilation=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            Down(32, 64),
            Down(64, 128),
            Down(128, 256),
            Down(256, 512),
            Down(512, 1024),
        )
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 50)

    def forward(self, x):
        """Return (N, 50) class logits for the input batch."""
        x = self.features(x)
        # Global average pooling over whatever spatial size remains.
        x = F.avg_pool2d(x, kernel_size=x.size()[2:])
        x = x.view(x.shape[0], -1)
        # BUG FIX: functional F.dropout2d defaults to training=True, so
        # the original applied dropout during eval as well, corrupting
        # validation predictions.  Gate dropout on self.training, and use
        # element-wise F.dropout — channel-wise dropout2d is not
        # meaningful on a flat FC activation vector.
        x = F.relu(F.dropout(self.fc1(x), training=self.training))
        return self.fc2(x)


class EscDataset(Dataset):
    """ESC-50 dataset over precomputed spectral-envelope .npy files.

    Folds 1-4 form the training split, fold 5 the validation split.
    __getitem__ returns (spectrum array of shape (1, H, W), target id).
    """

    def __init__(self, data_dir, train_flag):
        self.data_dir = data_dir
        self.df = pd.read_csv(os.path.join(self.data_dir, 'meta/esc50.csv'))
        if train_flag:
            self.df = self.df[self.df['fold'] != 5]
        else:
            self.df = self.df[self.df['fold'] == 5]
        # drop=True: plain reset_index() would inject the old row labels
        # as a spurious 'index' column into the DataFrame.
        self.df = self.df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        fname = self.df['filename'][idx]
        # fname[:-4] strips the '.wav' suffix; the preprocessed array
        # lives at data_dir/npy/<stem>.npy.
        sp = np.load(os.path.join(self.data_dir, 'npy/' + fname[:-4] + '.npy'))
        return sp, self.df['target'][idx]


def train_epoch(data_loader, model, criterion, optimizer):
    """Run one training epoch and print the mean per-batch loss.

    Batches are moved to whatever device the model's parameters live
    on, so this works on CPU as well as GPU (the original hard-coded
    .cuda(), which crashes on CUDA-less machines; behavior on GPU is
    unchanged).
    """
    model.train()
    device = next(model.parameters()).device
    loss_sum = 0
    for data, target in tqdm(data_loader):
        data = data.to(device)
        target = target.to(device)

        output = model(data)
        loss = criterion(output, target)

        # Standard step: clear stale grads, backprop, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

    loss_epoch = loss_sum / len(data_loader)
    print('train_loss = ', loss_epoch)


def val_epoch(data_loader, model, criterion):
    """Evaluate the model on data_loader: print mean loss and accuracy."""
    model.eval()
    device = next(model.parameters()).device
    loss_sum = 0
    correct = 0
    data_num = 0
    # FIX: gradients are never needed at validation time; the original
    # tracked them for every forward pass, wasting memory and compute.
    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data = data.to(device)
            target = target.to(device)

            output = model(data)
            loss = criterion(output, target)
            loss_sum += loss.item()

            # Predicted class = argmax over the logits.
            _, preds = torch.max(output, dim=1)
            correct += (preds == target).sum().item()
            data_num += target.size(0)

    loss_epoch = loss_sum / len(data_loader)
    print('val_loss = ', loss_epoch)

    accuracy = float(correct) / data_num
    print('accuracy = ', accuracy)


def main():
    """Train EscNet on ESC-50 for 10 epochs, checkpointing every epoch."""
    data_dir = 'data'
    train_dataset = EscDataset(data_dir, train_flag=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=0)
    val_dataset = EscDataset(data_dir, train_flag=False)
    val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=0)

    model = EscNet().cuda()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    # FIX: torch.save raises if the target directory does not exist, so
    # create the checkpoint directory up front.
    ckpt_dir = os.path.join(data_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(10):
        train_epoch(train_loader, model, criterion, optimizer)
        val_epoch(val_loader, model, criterion)

        state = {'state_dict': model.state_dict()}
        filename = os.path.join(ckpt_dir, '{0:04d}.pth.tar'.format(epoch))
        torch.save(state, filename)


# Entry-point guard: run training only when executed as a script.
if __name__ == '__main__':
    main()

 

test accuracyは8%弱までしか上がりませんでした。