forked from mrq/DL-Art-School

grand: support validation mode

parent e55d949855
commit 5bc9772cb0
@@ -44,10 +44,12 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
     Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
     """
     def __init__(self, opt):
+        sample_rate = 22050 # Fixed.
         paired_dataset_args = opt['paired_dataset_args']
-        unsupervised_audio_args = opt['unsupervised_audio_args']
-        text_corpus_args = opt['text_corpus_args']
-        sample_rate = 22050
+        self.only_paired = opt_get(opt, ['only_paired'], False)
+        if not self.only_paired:
+            unsupervised_audio_args = opt['unsupervised_audio_args']
+            text_corpus_args = opt['text_corpus_args']

         self.max_paired_audio_length = opt['max_paired_audio_length']
         self.max_paired_text_length = opt['max_paired_text_length']
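Note: the new `only_paired` flag is read with `opt_get`, DL-Art-School's nested-option reader, and defaults to `False`, so existing training configs keep their behavior. A minimal sketch of the lookup this relies on; the helper below is an illustrative stand-in inferred from the call site, not the repo's exact implementation:

```python
# Illustrative stand-in for opt_get(opt, keys, default): walk a nested
# dict by a list of keys, falling back to the default when any key is
# missing. (Assumed behavior, matching how it is called above.)
def opt_get(opt, keys, default=None):
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

assert opt_get({'only_paired': True}, ['only_paired'], False) is True
assert opt_get({}, ['only_paired'], False) is False  # default for old configs
```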
@@ -61,14 +63,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         paired_dataset_args['sample_rate'] = sample_rate
         paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
         paired_dataset_args['max_text_length'] = self.max_paired_text_length
-        unsupervised_audio_args['sampling_rate'] = sample_rate
-        unsupervised_audio_args['do_augmentation'] = False
-        unsupervised_audio_args['resample_clip'] = False
-        unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
-
         self.speech_and_text = build_paired_voice_dataset(paired_dataset_args)
-        self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
-        self.text = HfDataset(**text_corpus_args)

+        if not self.only_paired:
+            unsupervised_audio_args['sampling_rate'] = sample_rate
+            unsupervised_audio_args['do_augmentation'] = False
+            unsupervised_audio_args['resample_clip'] = False
+            unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
+            self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
+            self.text = HfDataset(**text_corpus_args)
+
     def fetch_text_at(self, i):
         try:
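Since the unsupervised audio and text corpora are now constructed only when `only_paired` is false, a validation config can omit `unsupervised_audio_args` and `text_corpus_args` entirely. A small illustration of why the guard matters (the config dict here is hypothetical):

```python
# Without the guard, the unconditional opt['unsupervised_audio_args']
# lookup would raise KeyError on a paired-only validation config.
val_opt = {
    'only_paired': True,
    'paired_dataset_args': {'path': ['test-clean_list.txt']},  # hypothetical
}
assert 'unsupervised_audio_args' not in val_opt  # fine in only_paired mode
```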
@@ -76,7 +79,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
             tok = self.speech_and_text.get_text(txt)
             padding_required = self.max_solo_text_length - tok.shape[0]
             if padding_required < 0:
-                # Just truncate since there is no conditioning requried.
+                # Just truncate since there is no conditioning required.
                 tok = tok[:self.max_solo_text_length]
             elif padding_required > 0:
                 tok = F.pad(tok, (0, padding_required))
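Aside from the comment typo fix, this hunk is unchanged: solo text tokens are truncated when they exceed `max_solo_text_length` and right-padded with zeros otherwise. A self-contained check of that pad-or-truncate logic, with an arbitrary length of 5:

```python
import torch
import torch.nn.functional as F

# Mirror of the pad-or-truncate logic above; max_solo_text_length is arbitrary.
max_solo_text_length = 5
for tok in (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3, 4, 5, 6, 7])):
    padding_required = max_solo_text_length - tok.shape[0]
    if padding_required < 0:
        tok = tok[:max_solo_text_length]          # too long: truncate
    elif padding_required > 0:
        tok = F.pad(tok, (0, padding_required))   # too short: zero-pad on the right
    assert tok.shape[0] == max_solo_text_length
```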
@@ -88,29 +91,45 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):

     def __getitem__(self, i):
         snt = self.speech_and_text[i % len(self.speech_and_text)]
-        sp = self.speech[i % len(self.speech)]
-        txt, txt_tok = self.fetch_text_at(i % len(self.text))
-
-        return {
-            'paired_audio': snt['wav'],
-            'paired_audio_lengths': snt['wav_lengths'],
-            'paired_text': snt['real_text'],
-            'paired_text_tokens': snt['padded_text'],
-            'paired_file': snt['filenames'],
-            'speech_audio': sp['clip'],
-            'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
-            'speech_file': sp['path'],
-            'text_text': txt,
-            'text_tokens': txt_tok,
-        }
+        if self.only_paired:
+            return {
+                'paired_audio': snt['wav'],
+                'paired_audio_lengths': snt['wav_lengths'],
+                'paired_text': snt['real_text'],
+                'paired_text_tokens': snt['padded_text'],
+                'paired_file': snt['filenames'],
+                'speech_audio': snt['wav'],
+                'speech_lengths': snt['wav_lengths'],
+                'speech_file': snt['filenames'],
+                'text_text': snt['real_text'],
+                'text_tokens': snt['padded_text'],
+            }
+        else:
+            sp = self.speech[i % len(self.speech)]
+            txt, txt_tok = self.fetch_text_at(i % len(self.text))
+            return {
+                'paired_audio': snt['wav'],
+                'paired_audio_lengths': snt['wav_lengths'],
+                'paired_text': snt['real_text'],
+                'paired_text_tokens': snt['padded_text'],
+                'paired_file': snt['filenames'],
+                'speech_audio': sp['clip'],
+                'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
+                'speech_file': sp['path'],
+                'text_text': txt,
+                'text_tokens': txt_tok,
+            }

     def __len__(self):
-        return max(len(self.speech), len(self.speech_and_text), len(self.text))
+        if self.only_paired:
+            return len(self.speech_and_text)
+        else:
+            return max(len(self.speech), len(self.speech_and_text), len(self.text))


 if __name__ == '__main__':
     batch_sz = 8
-    params = {
+    train_params = {
         'mode': 'grand_conjoined_voice',
         'phase': 'train',
         'n_workers': 0,
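In `only_paired` mode the paired sample is mirrored into the `speech_*` and `text_*` slots, so consumers see the same batch keys in both modes, and `__len__` shrinks to the paired dataset alone. The `i % len(...)` indexing lets the longest constituent dataset set the epoch length while the shorter ones wrap around; a toy sketch of that scheme with list stand-ins:

```python
# Toy illustration of the modulo indexing in __getitem__: the longest
# dataset drives the epoch, the shorter ones repeat.
speech_and_text = ['p0', 'p1', 'p2']       # stand-in, len 3
speech = ['s0', 's1', 's2', 's3', 's4']    # stand-in, len 5
text = ['t0', 't1']                        # stand-in, len 2

epoch_len = max(len(speech), len(speech_and_text), len(text))  # 5
for i in range(epoch_len):
    print(i, speech_and_text[i % len(speech_and_text)],
          speech[i % len(speech)], text[i % len(text)])
# only_paired mode instead returns len(speech_and_text) and never
# touches the other two datasets.
```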
@@ -133,10 +152,26 @@ if __name__ == '__main__':
             'cache_path': 'Z:\\huggingface_datasets\\cache',
         },
     }
+    val_params = {
+        'mode': 'grand_conjoined_voice',
+        'phase': 'val',
+        'n_workers': 0,
+        'batch_size': batch_sz,
+
+        'max_paired_audio_length': 255995,
+        'max_paired_text_length': 80,
+        'max_solo_text_length': 330,
+        'max_solo_audio_length': 300000,
+        'only_paired': True,
+        'paired_dataset_args': {
+            'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
+            'fetcher_mode': ['libritts'],
+        },
+    }
     from data import create_dataset, create_dataloader

-    ds = create_dataset(params)
-    dl = create_dataloader(ds, params)
+    ds = create_dataset(val_params)
+    dl = create_dataloader(ds, val_params)

     def save(b, i, ib, key):
         torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
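The smoke test now builds the loader from `val_params`; the `train_params` dict is left in place but unused. A hedged harness for eyeballing a few validation batches; it assumes default dict collation, the batch keys from `__getitem__` above, and `torch` imported at module top:

```python
# Illustrative only: iterate a few only_paired batches and confirm the
# speech_*/text_* fields mirror the paired ones, as __getitem__ promises.
for i, batch in enumerate(dl):
    assert torch.equal(batch['speech_audio'], batch['paired_audio'])
    assert batch['text_text'] == batch['paired_text']
    print(i, batch['paired_audio'].shape, batch['paired_audio_lengths'])
    if i >= 2:
        break
```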