DL-Art-School/codes/data/audio/stop_prediction_dataset_2.py
2021-08-18 18:29:38 -06:00

92 lines
3.4 KiB
Python

import os
import pathlib
import random
from munch import munchify
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get
# A dataset that consumes the result from the script `produce_libri_stretched_dataset`, which itself is a combined
# set of clips from the librivox corpus of equal length with the sentence alignment labeled.
class StopPredictionDataset(Dataset):
def __init__(self, opt):
path = opt['path']
label_compaction = opt_get(opt, ['label_compaction'], 1)
hp = munchify(hparams.create_hparams())
cache_path = os.path.join(path, 'cache.pth')
if os.path.exists(cache_path):
self.files = torch.load(cache_path)
else:
print("Building cache..")
self.files = list(pathlib.Path(path).glob('*.wav'))
torch.save(self.files, cache_path)
self.sampling_rate = 22050 # Fixed since the underlying data is also fixed at this SR.
self.mel_length = 2000
self.stft = TacotronSTFT(
hp.filter_length, hp.hop_length, hp.win_length,
hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
hp.mel_fmax)
self.label_compaction = label_compaction
def __getitem__(self, index):
audio, _ = load_wav_to_torch(self.files[index])
starts, ends = torch.load(str(self.files[index]).replace('.wav', '_se.pth'))
if audio.std() > 1:
print(f"Something is very wrong with the given audio. std_dev={audio.std()}. file={self.files[index]}")
return None
audio.clip_(-1, 1)
mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)
# Form labels.
labels_start = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
for s in starts:
# Mel compaction operates at a ratio of 1/256, the dataset also allows further compaction.
s = s // (256 * self.label_compaction)
if s >= 2000//self.label_compaction:
continue
labels_start[s] = 1
labels_end = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
for e in ends:
e = e // (256 * self.label_compaction)
if e >= 2000//self.label_compaction:
continue
labels_end[e] = 1
return {
'mels': mels,
'labels_start': labels_start,
'labels_end': labels_end,
}
def __len__(self):
return len(self.files)
if __name__ == '__main__':
opt = {
'path': 'D:\\data\\audio\\libritts\\stop_dataset',
'label_compaction': 4,
}
ds = StopPredictionDataset(opt)
j = 0
for i in tqdm(range(100)):
b = ds[random.randint(0, len(ds))]
start_indices = torch.nonzero(b['labels_start']).squeeze(1)
end_indices = torch.nonzero(b['labels_end']).squeeze(1)
assert len(end_indices) <= len(start_indices) # There should always be more START tokens then END tokens.
for i in range(len(end_indices)):
s = start_indices[i].item()*4
e = end_indices[i].item()*4
m = b['mels'][:, s:e]
save_mel_buffer_to_file(m, f'{j}.npy')
j += 1