# forked from mrq/DL-Art-School
# 92 lines, 3.4 KiB
import os
import pathlib
import random
from munch import munchify
from import Dataset
import torch
from tqdm import tqdm
from import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get
# A dataset that consumes the result from the script `produce_libri_stretched_dataset`, which itself is a combined
# set of clips from the librivox corpus of equal length with the sentence alignment labeled.
class StopPredictionDataset(Dataset):
    """Dataset of audio clips labelled with sentence start/end positions.

    Every ``*.wav`` file directly under ``opt['path']`` is one sample. Next to
    each clip a ``<name>_se.pth`` file is expected which ``torch.load``s into a
    ``(starts, ends)`` pair of positions (presumably audio-sample offsets -
    TODO confirm against the script that produces the data).

    Options read from ``opt``:
        path: Directory holding the clips. The file listing is cached in
            ``cache.pth`` inside that directory.
        label_compaction: Optional (default 1) extra down-sampling factor
            applied to the label vectors on top of the 256x ratio used to map
            positions onto mel frames.
    """

    # Ratio used to map a position from the label file onto a mel frame.
    SAMPLES_PER_MEL_FRAME = 256

    def __init__(self, opt):
        super().__init__()
        path = opt['path']
        label_compaction = opt_get(opt, ['label_compaction'], 1)
        hp = munchify(hparams.create_hparams())

        # Listing a large directory is slow, so the listing is stored once and
        # re-used on later runs.
        cache_path = os.path.join(path, 'cache.pth')
        if os.path.exists(cache_path):
            self.files = torch.load(cache_path)
        else:
            print("Building cache..")
            # Sorted so the index -> file mapping does not depend on OS listing order.
            self.files = sorted(pathlib.Path(path).glob('*.wav'))
            torch.save(self.files, cache_path)

        self.sampling_rate = 22050  # Fixed since the underlying data is also fixed at this SR.
        self.mel_length = 2000
        # NOTE(review): the final argument is assumed to be hp.mel_fmax, following
        # the usual TacotronSTFT signature - confirm against models.tacotron2.layers.
        self.stft = TacotronSTFT(
            hp.filter_length, hp.hop_length, hp.win_length,
            hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
            hp.mel_fmax)
        self.label_compaction = label_compaction

    def _build_labels(self, positions):
        """Return a 0/1 ``torch.long`` vector with a 1 in the slot of each position.

        The vector has ``mel_length // label_compaction`` slots. Positions that
        fall outside that window are ignored rather than raising.
        """
        num_slots = self.mel_length // self.label_compaction
        # Mel compaction operates at a ratio of 1/256, the dataset also allows further compaction.
        samples_per_slot = self.SAMPLES_PER_MEL_FRAME * self.label_compaction
        labels = torch.zeros((num_slots,), dtype=torch.long)
        for position in positions:
            slot = int(position) // samples_per_slot
            # A negative slot would silently wrap around to the end of the vector.
            if 0 <= slot < num_slots:
                labels[slot] = 1
        return labels

    def __getitem__(self, index):
        """Load one clip and return its mel spectrogram and label vectors.

        Returns:
            ``None`` when the audio's standard deviation exceeds 1 (treated as a
            corrupt clip); otherwise a dict with keys ``'mels'``,
            ``'labels_start'`` and ``'labels_end'``. ``'mels'`` is truncated to at
            most ``self.mel_length`` entries along its last axis.
        """
        wav_file = self.files[index]
        audio, _ = load_wav_to_torch(wav_file)

        # Swap only the file's own suffix; a plain str.replace would also rewrite
        # any '.wav' that happens to appear in a directory name.
        wav_path = pathlib.Path(wav_file)
        label_path = wav_path.with_name(wav_path.stem + '_se.pth')
        starts, ends = torch.load(str(label_path))

        std_dev = audio.std()
        if std_dev > 1:
            print(f"Something is very wrong with the given audio. std_dev={std_dev}. file={wav_file}")
            return None
        audio.clip_(-1, 1)
        mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)

        # Form labels.
        return {
            'mels': mels,
            'labels_start': self._build_labels(starts),
            'labels_end': self._build_labels(ends),
        }

    def __len__(self):
        """Return the number of clips in the dataset."""
        return len(self.files)
if __name__ == '__main__':
    # Smoke test: pull random samples and dump the mel slice between each
    # START/END label pair to '<counter>.npy' in the working directory.
    opt = {
        'path': 'D:\\data\\audio\\libritts\\stop_dataset',
        'label_compaction': 4,
    }
    ds = StopPredictionDataset(opt)
    # Each label slot spans this many mel frames.
    frames_per_label = opt['label_compaction']
    j = 0
    for _ in tqdm(range(100)):
        # randrange excludes len(ds); randint(0, len(ds)) could index one past the end.
        b = ds[random.randrange(len(ds))]
        if b is None:
            # The dataset returns None for clips it rejects.
            continue
        start_indices = torch.nonzero(b['labels_start']).squeeze(1)
        end_indices = torch.nonzero(b['labels_end']).squeeze(1)
        # There should never be more END tokens than START tokens.
        assert len(end_indices) <= len(start_indices)
        # zip stops at the shorter sequence, i.e. one slice per END token.
        for start, end in zip(start_indices.tolist(), end_indices.tolist()):
            s = start * frames_per_label
            e = end * frames_per_label
            m = b['mels'][:, s:e]
            save_mel_buffer_to_file(m, f'{j}.npy')
            j += 1