Add conditioning clips features to grand_conjoined

This commit is contained in:
James Betker 2021-12-29 14:44:32 -07:00
parent b12f47b36d
commit 51ce1b5007
3 changed files with 94 additions and 69 deletions

View File

@ -57,10 +57,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
self.max_solo_text_length = opt['max_solo_text_length']
self.collate = opt_get(opt, ['needs_collate'], False)
self.sample_rate = sample_rate
self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
load_conditioning = self.num_conditioning_candidates > 0
# Set some sane arguments for all three datasets.
paired_dataset_args['needs_collate'] = self.collate
paired_dataset_args['load_conditioning'] = False
paired_dataset_args['load_conditioning'] = load_conditioning
paired_dataset_args['num_conditioning_candidates'] = self.num_conditioning_candidates
paired_dataset_args['conditioning_length'] = self.conditioning_length
paired_dataset_args['sample_rate'] = sample_rate
paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
paired_dataset_args['max_text_length'] = self.max_paired_text_length
@ -70,6 +75,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
unsupervised_audio_args['sampling_rate'] = sample_rate
unsupervised_audio_args['do_augmentation'] = False
unsupervised_audio_args['resample_clip'] = False
unsupervised_audio_args['extra_samples'] = self.num_conditioning_candidates
unsupervised_audio_args['extra_sample_length'] = self.conditioning_length
if self.collate:
unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
@ -96,7 +103,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
fetched = self.speech_and_text[i % len(self.speech_and_text)]
if self.collate:
tseq, wav, path, text, cond = fetched
return {
res = {
'real_text': text,
'padded_text': tseq,
'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
@ -104,13 +111,26 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
'filenames': path
}
if self.num_conditioning_candidates > 0:
res['conditioning'] = cond
return res
else:
return fetched
def optionally_add_conditioning_candidates(self, res, paired, solo=None):
if self.num_conditioning_candidates > 0:
if solo is None:
res['paired_audio_conditioning'] = paired['conditioning']
res['speech_audio_conditioning'] = paired['conditioning']
else:
res['paired_audio_conditioning'] = paired['conditioning']
res['speech_audio_conditioning'] = solo['alt_clips']
return res
def __getitem__(self, i):
snt = self.fetch_snt_at(i)
if self.only_paired:
return {
return self.optionally_add_conditioning_candidates({
'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'],
'paired_text': snt['real_text'],
@ -121,14 +141,14 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
'speech_file': snt['filenames'],
'text_text': snt['real_text'],
'text_tokens': snt['padded_text'],
}
}, snt)
else:
txt, txt_tok = self.fetch_text_at(i % len(self.text))
sp = self.speech[i % len(self.speech)]
# Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
return {
return self.optionally_add_conditioning_candidates({
'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'],
'paired_text': snt['real_text'],
@ -139,7 +159,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
'speech_file': sp['path'],
'text_text': txt,
'text_tokens': txt_tok,
}
}, snt, sp)
def __len__(self):
if self.only_paired:
@ -161,9 +181,11 @@ if __name__ == '__main__':
'max_solo_text_length': 330,
'max_solo_audio_length': 300000,
'needs_collate': True,
'num_conditioning_candidates': 2,
'conditioning_length': 44000,
'paired_dataset_args': {
'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
'fetcher_mode': ['libritts'],
'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'],
'fetcher_mode': ['tsv'],
'use_bpe_tokenizer': False,
},
'unsupervised_audio_args': {
@ -198,7 +220,10 @@ if __name__ == '__main__':
ds, c = create_dataset(train_params, return_collate=True)
dl = create_dataloader(ds, train_params, collate_fn=c)
def save(b, i, ib, key):
def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
def decode(b, ib, key):
@ -208,10 +233,16 @@ if __name__ == '__main__':
m = None
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#save(b, i, ib, 'paired_audio')
print(f'Paired text: {b["paired_text"][ib]}')
save(b, i, ib, 'paired_audio')
save(b, i, ib, 'paired_audio_conditioning', 0)
save(b, i, ib, 'paired_audio_conditioning', 1)
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
#save(b, i, ib, 'speech_audio')
print(f'Text: {b["text_text"][ib]}')
print(f'Text decoded: {decode(b, ib, "text_tokens")}')
#save(b, i, ib, 'speech_audio_conditioning', 0)
#save(b, i, ib, 'speech_audio_conditioning', 1)
#print(f'Text: {b["text_text"][ib]}')
#print(f'Text decoded: {decode(b, ib, "text_tokens")}')
if i > 5:
break

View File

@ -1,6 +1,7 @@
import os
import os
import random
import sys
import torch
import torch.nn.functional as F
@ -11,7 +12,7 @@ from tokenizers import Tokenizer
from tqdm import tqdm
from transformers import GPT2TokenizerFast
from data.audio.unsupervised_audio_dataset import load_audio
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence, sequence_to_text
@ -68,7 +69,7 @@ class TextWavLoader(torch.utils.data.Dataset):
assert len(self.path) == len(fetcher_mode)
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.audiopaths_and_text = []
@ -119,33 +120,15 @@ class TextWavLoader(torch.utils.data.Dataset):
assert not torch.any(tokens == 0)
return tokens
def load_conditioning_candidates(self, path):
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if len(candidates) == 0:
print(f"No conditioning candidates found for {path} (not even the clip itself??)")
raise NotImplementedError()
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
related_clips = []
for k in range(self.conditioning_candidates):
rel_clip = load_audio(random.choice(candidates), self.sample_rate)
gap = rel_clip.shape[-1] - self.conditioning_length
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length]
related_clips.append(rel_clip)
return torch.stack(related_clips, dim=0)
def __getitem__(self, index):
try:
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None
cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
n=self.conditioning_candidates) if self.load_conditioning else None
except:
if self.debug_failures:
print(f"error loading {self.audiopaths_and_text[index][0]}")
return self[index+1]
print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
return self[(index+1) % len(self)]
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
@ -247,7 +230,9 @@ if __name__ == '__main__':
'max_wav_length': 255995,
'max_text_length': 200,
'sample_rate': 22050,
'load_conditioning': False,
'load_conditioning': True,
'num_conditioning_candidates': 2,
'conditioning_length': 44000,
'use_bpe_tokenizer': False,
}
from data import create_dataset, create_dataloader
@ -259,3 +244,4 @@ if __name__ == '__main__':
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
print(f"text_seq: {b['text_lengths'].max()}, speech_seq: {b['wav_lengths'].max()//1024}")

View File

@ -51,6 +51,41 @@ def load_audio(audiopath, sampling_rate):
return audio.unsqueeze(0)
def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True):
sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
candidates = []
if os.path.exists(sim_path):
similarities = torch.load(sim_path)
fname = os.path.basename(path)
if fname in similarities.keys():
candidates = similarities[fname]
else:
print(f'Similarities list found for {path} but {fname} was not in that list.')
if len(candidates) == 0:
print(f"Falling back to non-similarity list for {path}")
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if not include_self:
candidates.remove(path)
if len(candidates) == 0:
print(f"No conditioning candidates found for {path}")
raise NotImplementedError()
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
related_clips = []
for k in range(n):
rel_clip = load_audio(os.path.join(os.path.dirname(path), random.choice(candidates)), sample_rate)
gap = rel_clip.shape[-1] - sample_length
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
related_clips.append(rel_clip)
return torch.stack(related_clips, dim=0)
class UnsupervisedAudioDataset(torch.utils.data.Dataset):
def __init__(self, opt):
@ -77,8 +112,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
# "Extra samples" are other audio clips pulled from wav files in the same directory as the 'clip' wav file.
self.extra_samples = opt_get(opt, ['extra_samples'], 0)
self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
self.extra_sample_len *= self.sampling_rate
self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000)
self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
@ -92,38 +126,13 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
if self.extra_samples <= 0:
return None, 0
audiopath = self.audiopaths[index]
related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0]
assert audiopath in related_files
assert len(related_files) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
if len(related_files) == 0:
print(f"No related files for {audiopath}")
related_files.remove(audiopath)
related_clips = []
random.shuffle(related_clips)
i = 0
for related_file in related_files:
rel_clip = load_audio(related_file, self.sampling_rate)
gap = rel_clip.shape[-1] - self.extra_sample_len
if gap < 0:
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
related_clips.append(rel_clip)
i += 1
if i >= self.extra_samples:
break
actual_extra_samples = i
while i < self.extra_samples:
related_clips.append(torch.zeros(1, self.extra_sample_len))
i += 1
return torch.stack(related_clips, dim=0), actual_extra_samples
return load_similar_clips(audiopath, self.extra_sample_len, self.sampling_rate, n=self.extra_samples)
def __getitem__(self, index):
try:
# Split audio_norm into two tensors of equal size.
audio_norm, filename = self.get_audio_for_index(index)
alt_files, actual_samples = self.get_related_audio_for_index(index)
alt_files = self.get_related_audio_for_index(index)
except:
if self.debug_loading_failures:
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
@ -155,7 +164,6 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
output['resampled_clip'] = clips[1]
if self.extra_samples > 0:
output['alt_clips'] = alt_files
output['num_alt_clips'] = actual_samples
return output
def __len__(self):