92 lines
3.4 KiB
Python
92 lines
3.4 KiB
Python
|
import os
|
||
|
import pathlib
|
||
|
import random
|
||
|
|
||
|
from munch import munchify
|
||
|
from torch.utils.data import Dataset
|
||
|
import torch
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
|
||
|
from models.tacotron2 import hparams
|
||
|
from models.tacotron2.layers import TacotronSTFT
|
||
|
from models.tacotron2.taco_utils import load_wav_to_torch
|
||
|
from utils.util import opt_get
|
||
|
|
||
|
|
||
|
# A dataset that consumes the result from the script `produce_libri_stretched_dataset`, which itself is a combined
|
||
|
# set of clips from the librivox corpus of equal length with the sentence alignment labeled.
|
||
|
class StopPredictionDataset(Dataset):
    """Dataset of audio clips labelled with sentence start/end positions.

    Consumes the output of the `produce_libri_stretched_dataset` script. Each
    `<name>.wav` in the dataset directory is expected to have a sibling
    `<name>_se.pth` file holding a `(starts, ends)` pair of sample offsets.

    Options read from `opt`:
        path: Directory containing the `.wav` files and their `_se.pth` files.
        label_compaction: Optional integer (default 1). Number of mel frames
            that are folded into a single label slot.
    """

    # Number of audio samples represented by one mel frame when forming labels.
    SAMPLES_PER_MEL_FRAME = 256

    def __init__(self, opt):
        super().__init__()
        path = opt['path']
        label_compaction = opt_get(opt, ['label_compaction'], 1)
        hp = munchify(hparams.create_hparams())

        # The file listing is cached next to the data so it is only globbed once.
        cache_path = os.path.join(path, 'cache.pth')
        if os.path.exists(cache_path):
            # NOTE: torch.load unpickles the cache; only use dataset directories you trust.
            self.files = torch.load(cache_path)
        else:
            print("Building cache..")
            self.files = list(pathlib.Path(path).glob('*.wav'))
            torch.save(self.files, cache_path)

        self.sampling_rate = 22050  # Fixed since the underlying data is also fixed at this SR.
        self.mel_length = 2000  # Mels are truncated to at most this many frames.
        self.stft = TacotronSTFT(
            hp.filter_length, hp.hop_length, hp.win_length,
            hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
            hp.mel_fmax)
        self.label_compaction = label_compaction

    def _make_labels(self, positions):
        """Return a 0/1 long tensor marking the label slot of each sample offset.

        The tensor has `mel_length // label_compaction` slots. Offsets that map
        outside of that range (including negative offsets) are ignored.
        """
        num_slots = self.mel_length // self.label_compaction
        samples_per_slot = self.SAMPLES_PER_MEL_FRAME * self.label_compaction
        labels = torch.zeros((num_slots,), dtype=torch.long)
        for position in positions:
            # int() accepts both plain numbers and 0-dim tensors.
            slot = int(position) // samples_per_slot
            if 0 <= slot < num_slots:
                labels[slot] = 1
        return labels

    def __getitem__(self, index):
        """Load one clip and return its mel spectrogram and start/end labels.

        Returns:
            A dict with keys 'mels', 'labels_start' and 'labels_end', or None
            when the audio's standard deviation exceeds 1 (treated as corrupt).
        """
        wav_path = self.files[index]
        audio, _ = load_wav_to_torch(wav_path)

        # Build the sidecar path from the file name only, so a '.wav' substring
        # elsewhere in the path is never rewritten.
        wav_file = pathlib.Path(wav_path)
        starts, ends = torch.load(str(wav_file.with_name(wav_file.stem + '_se.pth')))

        std_dev = audio.std()
        if std_dev > 1:
            print(f"Something is very wrong with the given audio. std_dev={std_dev}. file={wav_path}")
            return None
        audio.clip_(-1, 1)
        mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)

        # Form labels.
        # Mel compaction operates at a ratio of 1/256, the dataset also allows further compaction.
        return {
            'mels': mels,
            'labels_start': self._make_labels(starts),
            'labels_end': self._make_labels(ends),
        }

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.files)
|
||
|
|
||
|
|
||
|
# Smoke test: sample random items and dump the mel segment between each
# paired start/end label to sequentially numbered .npy files.
if __name__ == '__main__':
    opt = {
        'path': 'D:\\data\\audio\\libritts\\stop_dataset',
        'label_compaction': 4,
    }
    ds = StopPredictionDataset(opt)
    # Labels are compacted by this factor, so scale label indices back to mel frames.
    compaction = opt['label_compaction']
    j = 0
    for _ in tqdm(range(100)):
        # randrange excludes the upper bound; randint(0, len(ds)) could index past the end.
        b = ds[random.randrange(len(ds))]
        if b is None:
            # The dataset returns None for audio it rejects; skip those samples.
            continue
        start_indices = torch.nonzero(b['labels_start']).squeeze(1)
        end_indices = torch.nonzero(b['labels_end']).squeeze(1)
        assert len(end_indices) <= len(start_indices)  # There should never be more END tokens than START tokens.
        for start, end in zip(start_indices.tolist(), end_indices.tolist()):
            m = b['mels'][:, start * compaction:end * compaction]
            save_mel_buffer_to_file(m, f'{j}.npy')
            j += 1
|