Use torch resampler

This commit is contained in:
James Betker 2022-01-05 15:47:22 -07:00
parent 38aba6f88d
commit 0fe34f57d1
2 changed files with 13 additions and 6 deletions

View File

@ -222,8 +222,8 @@ if __name__ == '__main__':
batch_sz = 8
params = {
'mode': 'paired_voice_audio',
'path': ['Z:\\clips\\podcasts-0-transcribed.tsv'],
'fetcher_mode': ['tsv'],
'path': ['Y:\\bigasr_dataset\\hifi_tts_mp3\\test.txt'],
'fetcher_mode': ['libritts'],
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,
@ -238,11 +238,20 @@ if __name__ == '__main__':
}
from data import create_dataset, create_dataloader
def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
ds, c = create_dataset(params, return_collate=True)
dl = create_dataloader(ds, params, collate_fn=c)
i = 0
m = None
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
print(f"text_seq: {b['text_lengths'].max()}, speech_seq: {b['wav_lengths'].max()//1024}")
print(f'{i} {ib} {b["real_text"][ib]}')
save(b, i, ib, 'wav')
if i > 5:
break

View File

@ -38,9 +38,7 @@ def load_audio(audiopath, sampling_rate):
audio = audio[:, 0]
if lsr != sampling_rate:
#if lsr < sampling_rate:
# warn(f'{audiopath} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {sampling_rate}. This is not a good idea.')
audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=sampling_rate/lsr, mode='nearest', recompute_scale_factor=False).squeeze()
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.