forked from mrq/DL-Art-School
Add conditioning clips features to grand_conjoined
parent b12f47b36d
commit 51ce1b5007
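This commit threads two new options, num_conditioning_candidates and conditioning_length, from GrandConjoinedDataset down into the paired (TextWavLoader) and unsupervised audio datasets, so each sample can carry extra "conditioning" clips drawn from audio related to the main clip. For orientation, here is a minimal sketch of a config exercising the new options, assembled only from values visible in the diff below (the path is a placeholder):

# Sketch only: option names and values are copied from the __main__ test block in
# this diff; the path is a placeholder, not a real dataset location.
opt = {
    'needs_collate': True,
    'max_solo_text_length': 330,
    'max_solo_audio_length': 300000,
    'num_conditioning_candidates': 2,   # > 0 switches conditioning-clip loading on
    'conditioning_length': 44000,       # length of each conditioning clip, in samples
    'paired_dataset_args': {
        'path': ['/path/to/transcribed_clips.tsv'],
        'fetcher_mode': ['tsv'],
        'use_bpe_tokenizer': False,
    },
    # 'unsupervised_audio_args': { ... }  -- see the diff for the remaining arguments
}
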
@@ -57,10 +57,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
        self.max_solo_text_length = opt['max_solo_text_length']
        self.collate = opt_get(opt, ['needs_collate'], False)
        self.sample_rate = sample_rate
        self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
        self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
        load_conditioning = self.num_conditioning_candidates > 0

        # Set some sane arguments for all three datasets.
        paired_dataset_args['needs_collate'] = self.collate
        paired_dataset_args['load_conditioning'] = False
        paired_dataset_args['load_conditioning'] = load_conditioning
        paired_dataset_args['num_conditioning_candidates'] = self.num_conditioning_candidates
        paired_dataset_args['conditioning_length'] = self.conditioning_length
        paired_dataset_args['sample_rate'] = sample_rate
        paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
        paired_dataset_args['max_text_length'] = self.max_paired_text_length

@@ -70,6 +75,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
        unsupervised_audio_args['sampling_rate'] = sample_rate
        unsupervised_audio_args['do_augmentation'] = False
        unsupervised_audio_args['resample_clip'] = False
        unsupervised_audio_args['extra_samples'] = self.num_conditioning_candidates
        unsupervised_audio_args['extra_sample_length'] = self.conditioning_length
        if self.collate:
            unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
        self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)

@@ -96,7 +103,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
        fetched = self.speech_and_text[i % len(self.speech_and_text)]
        if self.collate:
            tseq, wav, path, text, cond = fetched
            return {
            res = {
                'real_text': text,
                'padded_text': tseq,
                'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),

@@ -104,13 +111,26 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
                'filenames': path
            }
            if self.num_conditioning_candidates > 0:
                res['conditioning'] = cond
            return res
        else:
            return fetched

    def optionally_add_conditioning_candidates(self, res, paired, solo=None):
        if self.num_conditioning_candidates > 0:
            if solo is None:
                res['paired_audio_conditioning'] = paired['conditioning']
                res['speech_audio_conditioning'] = paired['conditioning']
            else:
                res['paired_audio_conditioning'] = paired['conditioning']
                res['speech_audio_conditioning'] = solo['alt_clips']
        return res

    def __getitem__(self, i):
        snt = self.fetch_snt_at(i)
        if self.only_paired:
            return {
            return self.optionally_add_conditioning_candidates({
                'paired_audio': snt['wav'],
                'paired_audio_lengths': snt['wav_lengths'],
                'paired_text': snt['real_text'],

@@ -121,14 +141,14 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                'speech_file': snt['filenames'],
                'text_text': snt['real_text'],
                'text_tokens': snt['padded_text'],
            }
            }, snt)
        else:
            txt, txt_tok = self.fetch_text_at(i % len(self.text))
            sp = self.speech[i % len(self.speech)]
            # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
            sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
            sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
            return {
            return self.optionally_add_conditioning_candidates({
                'paired_audio': snt['wav'],
                'paired_audio_lengths': snt['wav_lengths'],
                'paired_text': snt['real_text'],

@@ -139,7 +159,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                'speech_file': sp['path'],
                'text_text': txt,
                'text_tokens': txt_tok,
            }
            }, snt, sp)

    def __len__(self):
        if self.only_paired:

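With the changes above, any sample produced while num_conditioning_candidates > 0 carries two extra entries alongside the usual paired/text fields. A sketch of what that looks like for a single item, assuming num_conditioning_candidates=2 and conditioning_length=44000 (key names come from the diff; the shapes are an inference from load_similar_clips further down, which stacks n clips of conditioning_length samples each):

# Sketch: 'ds' stands for a GrandConjoinedDataset built with the options above.
sample = ds[0]
sample['paired_audio_conditioning'].shape   # torch.Size([2, 1, 44000])
sample['speech_audio_conditioning'].shape   # torch.Size([2, 1, 44000])
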
@@ -161,9 +181,11 @@ if __name__ == '__main__':
        'max_solo_text_length': 330,
        'max_solo_audio_length': 300000,
        'needs_collate': True,
        'num_conditioning_candidates': 2,
        'conditioning_length': 44000,
        'paired_dataset_args': {
            'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
            'fetcher_mode': ['libritts'],
            'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'],
            'fetcher_mode': ['tsv'],
            'use_bpe_tokenizer': False,
        },
        'unsupervised_audio_args': {

@@ -198,8 +220,11 @@ if __name__ == '__main__':
    ds, c = create_dataset(train_params, return_collate=True)
    dl = create_dataloader(ds, train_params, collate_fn=c)

    def save(b, i, ib, key):
        torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
    def save(b, i, ib, key, c=None):
        if c is not None:
            torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
        else:
            torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

    def decode(b, ib, key):
        return ds.speech_and_text.tokenizer.decode(b[key][ib].cpu().numpy())

@@ -208,10 +233,16 @@ if __name__ == '__main__':
    m = None
    for i, b in tqdm(enumerate(dl)):
        for ib in range(batch_sz):
            #save(b, i, ib, 'paired_audio')
            print(f'Paired text: {b["paired_text"][ib]}')
            save(b, i, ib, 'paired_audio')
            save(b, i, ib, 'paired_audio_conditioning', 0)
            save(b, i, ib, 'paired_audio_conditioning', 1)
            print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
            print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
            #save(b, i, ib, 'speech_audio')
            print(f'Text: {b["text_text"][ib]}')
            print(f'Text decoded: {decode(b, ib, "text_tokens")}')
            #save(b, i, ib, 'speech_audio_conditioning', 0)
            #save(b, i, ib, 'speech_audio_conditioning', 1)
            #print(f'Text: {b["text_text"][ib]}')
            #print(f'Text decoded: {decode(b, ib, "text_tokens")}')
        if i > 5:
            break

@@ -1,6 +1,7 @@
import os
import os
import random
import sys

import torch
import torch.nn.functional as F

@@ -11,7 +12,7 @@ from tokenizers import Tokenizer
from tqdm import tqdm
from transformers import GPT2TokenizerFast

from data.audio.unsupervised_audio_dataset import load_audio
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence, sequence_to_text

@@ -68,7 +69,7 @@ class TextWavLoader(torch.utils.data.Dataset):
        assert len(self.path) == len(fetcher_mode)

        self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
        self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3)
        self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
        self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
        self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
        self.audiopaths_and_text = []

@@ -119,33 +120,15 @@ class TextWavLoader(torch.utils.data.Dataset):
        assert not torch.any(tokens == 0)
        return tokens

    def load_conditioning_candidates(self, path):
        candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
        assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
        if len(candidates) == 0:
            print(f"No conditioning candidates found for {path} (not even the clip itself??)")
            raise NotImplementedError()
        # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
        related_clips = []
        for k in range(self.conditioning_candidates):
            rel_clip = load_audio(random.choice(candidates), self.sample_rate)
            gap = rel_clip.shape[-1] - self.conditioning_length
            if gap < 0:
                rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
            elif gap > 0:
                rand_start = random.randint(0, gap)
                rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length]
            related_clips.append(rel_clip)
        return torch.stack(related_clips, dim=0)

    def __getitem__(self, index):
        try:
            tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
            cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None
            cond = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
                                      n=self.conditioning_candidates) if self.load_conditioning else None
        except:
            if self.debug_failures:
                print(f"error loading {self.audiopaths_and_text[index][0]}")
            return self[index+1]
                print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
            return self[(index+1) % len(self)]
        if wav is None or \
            (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
            (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):

@@ -247,7 +230,9 @@ if __name__ == '__main__':
        'max_wav_length': 255995,
        'max_text_length': 200,
        'sample_rate': 22050,
        'load_conditioning': False,
        'load_conditioning': True,
        'num_conditioning_candidates': 2,
        'conditioning_length': 44000,
        'use_bpe_tokenizer': False,
    }
    from data import create_dataset, create_dataloader

@@ -259,3 +244,4 @@ if __name__ == '__main__':
    for i, b in tqdm(enumerate(dl)):
        for ib in range(batch_sz):
            print(f"text_seq: {b['text_lengths'].max()}, speech_seq: {b['wav_lengths'].max()//1024}")

@@ -51,6 +51,41 @@ def load_audio(audiopath, sampling_rate):
    return audio.unsqueeze(0)


def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True):
    sim_path = os.path.join(os.path.dirname(path), 'similarities.pth')
    candidates = []
    if os.path.exists(sim_path):
        similarities = torch.load(sim_path)
        fname = os.path.basename(path)
        if fname in similarities.keys():
            candidates = similarities[fname]
        else:
            print(f'Similarities list found for {path} but {fname} was not in that list.')
    if len(candidates) == 0:
        print(f"Falling back to non-similarity list for {path}")
        candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]

    assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
    if not include_self:
        candidates.remove(path)
    if len(candidates) == 0:
        print(f"No conditioning candidates found for {path}")
        raise NotImplementedError()

    # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
    related_clips = []
    for k in range(n):
        rel_clip = load_audio(os.path.join(os.path.dirname(path), random.choice(candidates)), sample_rate)
        gap = rel_clip.shape[-1] - sample_length
        if gap < 0:
            rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
        elif gap > 0:
            rand_start = random.randint(0, gap)
            rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
        related_clips.append(rel_clip)
    return torch.stack(related_clips, dim=0)


class UnsupervisedAudioDataset(torch.utils.data.Dataset):

    def __init__(self, opt):

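load_similar_clips prefers a precomputed similarities.pth file sitting next to the clip and otherwise falls back to treating every audio file in the same directory as a candidate. A usage sketch, assuming a placeholder path (the shape follows from load_audio returning a (1, num_samples) tensor):

# Sketch: pull 2 related conditioning clips of 44000 samples for a given wav file.
# '/data/clips/speaker0/utt1.wav' is a placeholder path, not taken from the commit.
cond = load_similar_clips('/data/clips/speaker0/utt1.wav',
                          sample_length=44000, sample_rate=22050, n=2)
# cond has shape (2, 1, 44000): each clip is zero-padded or randomly cropped to
# exactly sample_length samples, then the n clips are stacked along dim 0.
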
@@ -77,8 +112,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):

        # "Extra samples" are other audio clips pulled from wav files in the same directory as the 'clip' wav file.
        self.extra_samples = opt_get(opt, ['extra_samples'], 0)
        self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
        self.extra_sample_len *= self.sampling_rate
        self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000)

        self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)

@@ -92,38 +126,13 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
        if self.extra_samples <= 0:
            return None, 0
        audiopath = self.audiopaths[index]
        related_files = find_files_of_type('img', os.path.dirname(audiopath), qualifier=is_audio_file)[0]
        assert audiopath in related_files
        assert len(related_files) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
        if len(related_files) == 0:
            print(f"No related files for {audiopath}")
        related_files.remove(audiopath)
        related_clips = []
        random.shuffle(related_clips)
        i = 0
        for related_file in related_files:
            rel_clip = load_audio(related_file, self.sampling_rate)
            gap = rel_clip.shape[-1] - self.extra_sample_len
            if gap < 0:
                rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
            elif gap > 0:
                rand_start = random.randint(0, gap)
                rel_clip = rel_clip[:, rand_start:rand_start+self.extra_sample_len]
            related_clips.append(rel_clip)
            i += 1
            if i >= self.extra_samples:
                break
        actual_extra_samples = i
        while i < self.extra_samples:
            related_clips.append(torch.zeros(1, self.extra_sample_len))
            i += 1
        return torch.stack(related_clips, dim=0), actual_extra_samples
        return load_similar_clips(audiopath, self.extra_sample_len, self.sampling_rate, n=self.extra_samples)

    def __getitem__(self, index):
        try:
            # Split audio_norm into two tensors of equal size.
            audio_norm, filename = self.get_audio_for_index(index)
            alt_files, actual_samples = self.get_related_audio_for_index(index)
            alt_files = self.get_related_audio_for_index(index)
        except:
            if self.debug_loading_failures:
                print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")

@@ -155,7 +164,6 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
            output['resampled_clip'] = clips[1]
        if self.extra_samples > 0:
            output['alt_clips'] = alt_files
            output['num_alt_clips'] = actual_samples
        return output

    def __len__(self):