Add conditioning clips features to grand_conjoined

This commit is contained in:
James Betker 2021-12-29 14:44:32 -07:00
parent b12f47b36d
commit 51ce1b5007
3 changed files with 94 additions and 69 deletions

View File

@ -57,10 +57,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
self.max_solo_text_length = opt['max_solo_text_length'] self.max_solo_text_length = opt['max_solo_text_length']
self.collate = opt_get(opt, ['needs_collate'], False) self.collate = opt_get(opt, ['needs_collate'], False)
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
load_conditioning = self.num_conditioning_candidates > 0
# Set some sane arguments for all three datasets. # Set some sane arguments for all three datasets.
paired_dataset_args['needs_collate'] = self.collate paired_dataset_args['needs_collate'] = self.collate
paired_dataset_args['load_conditioning'] = False paired_dataset_args['load_conditioning'] = load_conditioning
paired_dataset_args['num_conditioning_candidates'] = self.num_conditioning_candidates
paired_dataset_args['conditioning_length'] = self.conditioning_length
paired_dataset_args['sample_rate'] = sample_rate paired_dataset_args['sample_rate'] = sample_rate
paired_dataset_args['max_wav_length'] = self.max_paired_audio_length paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
paired_dataset_args['max_text_length'] = self.max_paired_text_length paired_dataset_args['max_text_length'] = self.max_paired_text_length
@ -70,6 +75,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
unsupervised_audio_args['sampling_rate'] = sample_rate unsupervised_audio_args['sampling_rate'] = sample_rate
unsupervised_audio_args['do_augmentation'] = False unsupervised_audio_args['do_augmentation'] = False
unsupervised_audio_args['resample_clip'] = False unsupervised_audio_args['resample_clip'] = False
unsupervised_audio_args['extra_samples'] = self.num_conditioning_candidates
unsupervised_audio_args['extra_sample_length'] = self.conditioning_length
if self.collate: if self.collate:
unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
self.speech = UnsupervisedAudioDataset(unsupervised_audio_args) self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
@ -96,7 +103,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
fetched = self.speech_and_text[i % len(self.speech_and_text)] fetched = self.speech_and_text[i % len(self.speech_and_text)]
if self.collate: if self.collate:
tseq, wav, path, text, cond = fetched tseq, wav, path, text, cond = fetched
return { res = {
'real_text': text, 'real_text': text,
'padded_text': tseq, 'padded_text': tseq,
'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long), 'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
@ -104,13 +111,26 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long), 'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
'filenames': path 'filenames': path
} }
if self.num_conditioning_candidates > 0:
res['conditioning'] = cond
return res
else: else:
return fetched return fetched
def optionally_add_conditioning_candidates(self, res, paired, solo=None):
if self.num_conditioning_candidates > 0:
if solo is None:
res['paired_audio_conditioning'] = paired['conditioning']
res['speech_audio_conditioning'] = paired['conditioning']
else:
res['paired_audio_conditioning'] = paired['conditioning']
res['speech_audio_conditioning'] = solo['alt_clips']
return res
def __getitem__(self, i): def __getitem__(self, i):
snt = self.fetch_snt_at(i) snt = self.fetch_snt_at(i)
if self.only_paired: if self.only_paired:
return { return self.optionally_add_conditioning_candidates({
'paired_audio': snt['wav'], 'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'], 'paired_audio_lengths': snt['wav_lengths'],
'paired_text': snt['real_text'], 'paired_text': snt['real_text'],
@ -121,14 +141,14 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
'speech_file': snt['filenames'], 'speech_file': snt['filenames'],
'text_text': snt['real_text'], 'text_text': snt['real_text'],
'text_tokens': snt['padded_text'], 'text_tokens': snt['padded_text'],
} }, snt)
else: else:
txt, txt_tok = self.fetch_text_at(i % len(self.text)) txt, txt_tok = self.fetch_text_at(i % len(self.text))
sp = self.speech[i % len(self.speech)] sp = self.speech[i % len(self.speech)]
# Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise. # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
sp['clip'] = sp['clip'][:, :self.max_solo_audio_length] sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length) sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
return { return self.optionally_add_conditioning_candidates({
'paired_audio': snt['wav'], 'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'], 'paired_audio_lengths': snt['wav_lengths'],
'paired_text': snt['real_text'], 'paired_text': snt['real_text'],
@ -139,7 +159,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
'speech_file': sp['path'], 'speech_file': sp['path'],
'text_text': txt, 'text_text': txt,
'text_tokens': txt_tok, 'text_tokens': txt_tok,
} }, snt, sp)
def __len__(self): def __len__(self):
if self.only_paired: if self.only_paired:
@ -161,9 +181,11 @@ if __name__ == '__main__':
'max_solo_text_length': 330, 'max_solo_text_length': 330,
'max_solo_audio_length': 300000, 'max_solo_audio_length': 300000,
'needs_collate': True, 'needs_collate': True,
'num_conditioning_candidates': 2,
'conditioning_length': 44000,
'paired_dataset_args': { 'paired_dataset_args': {
'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'], 'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'],
'fetcher_mode': ['libritts'], 'fetcher_mode': ['tsv'],
'use_bpe_tokenizer': False, 'use_bpe_tokenizer': False,
}, },
'unsupervised_audio_args': { 'unsupervised_audio_args': {
@ -198,7 +220,10 @@ if __name__ == '__main__':
ds, c = create_dataset(train_params, return_collate=True) ds, c = create_dataset(train_params, return_collate=True)
dl = create_dataloader(ds, train_params, collate_fn=c) dl = create_dataloader(ds, train_params, collate_fn=c)
def save(b, i, ib, key): def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050) torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
def decode(b, ib, key): def decode(b, ib, key):
@ -208,10 +233,16 @@ if __name__ == '__main__':
m = None m = None
for i, b in tqdm(enumerate(dl)): for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz): for ib in range(batch_sz):
#save(b, i, ib, 'paired_audio') save(b, i, ib, 'paired_audio')
print(f'Paired text: {b["paired_text"][ib]}') save(b, i, ib, 'paired_audio_conditioning', 0)
save(b, i, ib, 'paired_audio_conditioning', 1)
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}') print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
#save(b, i, ib, 'speech_audio') #save(b, i, ib, 'speech_audio')
print(f'Text: {b["text_text"][ib]}') #save(b, i, ib, 'speech_audio_conditioning', 0)
print(f'Text decoded: {decode(b, ib, "text_tokens")}') #save(b, i, ib, 'speech_audio_conditioning', 1)
#print(f'Text: {b["text_text"][ib]}')
#print(f'Text decoded: {decode(b, ib, "text_tokens")}')
if i > 5:
break

View File

@ -1,6 +1,7 @@
import os import os
import os import os
import random import random
import sys
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -11,7 +12,7 @@ from tokenizers import Tokenizer
from tqdm import tqdm from tqdm import tqdm
from transformers import GPT2TokenizerFast from transformers import GPT2TokenizerFast
from data.audio.unsupervised_audio_dataset import load_audio from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from data.util import find_files_of_type, is_audio_file from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence, sequence_to_text from models.tacotron2.text import text_to_sequence, sequence_to_text
@ -68,7 +69,7 @@ class TextWavLoader(torch.utils.data.Dataset):
assert len(self.path) == len(fetcher_mode) assert len(self.path) == len(fetcher_mode)
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False) self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3) self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100) self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False) self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.audiopaths_and_text = [] self.audiopaths_and_text = []
@ -119,33 +120,15 @@ class TextWavLoader(torch.utils.data.Dataset):
assert not torch.any(tokens == 0) assert not torch.any(tokens == 0)
return tokens return tokens
def load_conditioning_candidates(self, path):
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if len(candidates) == 0:
print(f"No conditioning candidates found for {path} (not even the clip itself??)")
raise NotImplementedError()
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
related_clips = []
for k in range(self.conditioning_candidates):
rel_clip = load_audio(random.choice(candidates), self.sample_rate)
gap = rel_clip.shape[-1] - self.conditioning_length
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length]
related_clips.append(rel_clip)
return torch.stack(related_clips, dim=0)
def __getitem__(self, index): def __getitem__(self, index):
try: try:
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index]) tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
n=self.conditioning_candidates) if self.load_conditioning else None
except: except:
if self.debug_failures: if self.debug_failures:
print(f"error loading {self.audiopaths_and_text[index][0]}") print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
return self[index+1] return self[(index+1) % len(self)]
if wav is None or \ if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len): (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
@ -247,7 +230,9 @@ if __name__ == '__main__':
'max_wav_length': 255995, 'max_wav_length': 255995,
'max_text_length': 200, 'max_text_length': 200,
'sample_rate': 22050, 'sample_rate': 22050,
'load_conditioning': False, 'load_conditioning': True,
'num_conditioning_candidates': 2,
'conditioning_length': 44000,
'use_bpe_tokenizer': False, 'use_bpe_tokenizer': False,
} }
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
@ -259,3 +244,4 @@ if __name__ == '__main__':
for i, b in tqdm(enumerate(dl)): for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz): for ib in range(batch_sz):
print(f"text_seq: {b['text_lengths'].max()}, speech_seq: {b['wav_lengths'].max()//1024}") print(f"text_seq: {b['text_lengths'].max()}, speech_seq: {b['wav_lengths'].max()//1024}")

View File

@ -51,6 +51,41 @@ def load_audio(audiopath, sampling_rate):
return audio.unsqueeze(0) return audio.unsqueeze(0)
def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True):
sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
candidates = []
if os.path.exists(sim_path):
similarities = torch.load(sim_path)
fname = os.path.basename(path)
if fname in similarities.keys():
candidates = similarities[fname]
else:
print(f'Similarities list found for {path} but {fname} was not in that list.')
if len(candidates) == 0:
print(f"Falling back to non-similarity list for {path}")
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if not include_self:
candidates.remove(path)
if len(candidates) == 0:
print(f"No conditioning candidates found for {path}")
raise NotImplementedError()
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
related_clips = []
for k in range(n):
rel_clip = load_audio(os.path.join(os.path.dirname(path), random.choice(candidates)), sample_rate)
gap = rel_clip.shape[-1] - sample_length
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
related_clips.append(rel_clip)
return torch.stack(related_clips, dim=0)
class UnsupervisedAudioDataset(torch.utils.data.Dataset): class UnsupervisedAudioDataset(torch.utils.data.Dataset):
def __init__(self, opt): def __init__(self, opt):
@ -77,8 +112,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
# "Extra samples" are other audio clips pulled from wav files in the same directory as the 'clip' wav file. # "Extra samples" are other audio clips pulled from wav files in the same directory as the 'clip' wav file.
self.extra_samples = opt_get(opt, ['extra_samples'], 0) self.extra_samples = opt_get(opt, ['extra_samples'], 0)
self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2) self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000)
self.extra_sample_len *= self.sampling_rate
self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True) self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
@ -92,38 +126,13 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
if self.extra_samples <= 0: if self.extra_samples <= 0:
return None, 0 return None, 0
audiopath = self.audiopaths[index] audiopath = self.audiopaths[index]
related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0] return load_similar_clips(audiopath, self.extra_sample_len, self.sampling_rate, n=self.extra_samples)
assert audiopath in related_files
assert len(related_files) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if len(related_files) == 0:
print(f"No related files for {audiopath}")
related_files.remove(audiopath)
related_clips = []
random.shuffle(related_clips)
i = 0
for related_file in related_files:
rel_clip = load_audio(related_file, self.sampling_rate)
gap = rel_clip.shape[-1] - self.extra_sample_len
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
related_clips.append(rel_clip)
i += 1
if i >= self.extra_samples:
break
actual_extra_samples = i
while i < self.extra_samples:
related_clips.append(torch.zeros(1, self.extra_sample_len))
i += 1
return torch.stack(related_clips, dim=0), actual_extra_samples
def __getitem__(self, index): def __getitem__(self, index):
try: try:
# Split audio_norm into two tensors of equal size. # Split audio_norm into two tensors of equal size.
audio_norm, filename = self.get_audio_for_index(index) audio_norm, filename = self.get_audio_for_index(index)
alt_files, actual_samples = self.get_related_audio_for_index(index) alt_files = self.get_related_audio_for_index(index)
except: except:
if self.debug_loading_failures: if self.debug_loading_failures:
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}") print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
@ -155,7 +164,6 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
output['resampled_clip'] = clips[1] output['resampled_clip'] = clips[1]
if self.extra_samples > 0: if self.extra_samples > 0:
output['alt_clips'] = alt_files output['alt_clips'] = alt_files
output['num_alt_clips'] = actual_samples
return output return output
def __len__(self): def __len__(self):