De-specify fast-paired-dataset

James Betker 2022-01-16 21:20:00 -07:00
parent 2b36ca5f8e
commit 1d30d79e34
2 changed files with 32 additions and 99 deletions

data/__init__.py

@@ -75,6 +75,12 @@ def create_dataset(dataset_opt, return_collate=False):
         default_params = create_hparams()
         default_params.update(dataset_opt)
         dataset_opt = munchify(default_params)
+    elif mode == 'fast_paired_voice_audio':
+        from data.audio.fast_paired_dataset import FastPairedVoiceDataset as D
+        from models.tacotron2.hparams import create_hparams
+        default_params = create_hparams()
+        default_params.update(dataset_opt)
+        dataset_opt = munchify(default_params)
     elif mode == 'gpt_tts':
         from data.audio.gpt_tts_dataset import GptTtsDataset as D
         from data.audio.gpt_tts_dataset import GptTtsCollater as C
@@ -100,7 +106,7 @@ def create_dataset(dataset_opt, return_collate=False):
 def get_dataset_debugger(dataset_opt):
     mode = dataset_opt['mode']
-    if mode == 'paired_voice_audio':
+    if mode == 'paired_voice_audio' or mode == 'fast_paired_voice_audio':
         from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
         return PairedVoiceDebugger()
     return None
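For reference, the hparams plumbing the new branch reuses follows a simple defaults-then-overrides pattern. A minimal runnable sketch; the names create_hparams_stub and resolve_dataset_opt are mine, and the real defaults come from models.tacotron2.hparams.create_hparams:

    from munch import munchify

    def create_hparams_stub():
        # Stand-in for models.tacotron2.hparams.create_hparams; the real
        # function returns a much larger set of Tacotron2 defaults.
        return {'sample_rate': 22050, 'text_cleaners': ['english_cleaners']}

    def resolve_dataset_opt(dataset_opt):
        defaults = create_hparams_stub()
        defaults.update(dataset_opt)  # user-supplied options override the defaults
        return munchify(defaults)     # enables attribute access, e.g. opt.sample_rate

    opt = resolve_dataset_opt({'mode': 'fast_paired_voice_audio', 'sample_rate': 24000})
    assert opt.sample_rate == 24000 and opt.mode == 'fast_paired_voice_audio'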

data/audio/fast_paired_dataset.py

@@ -9,23 +9,13 @@ import torch.utils.data
 import torchaudio
 from tqdm import tqdm

+from data.audio.paired_voice_audio_dataset import CharacterTokenizer
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
 from models.tacotron2.taco_utils import load_filepaths_and_text
 from models.tacotron2.text import text_to_sequence, sequence_to_text
 from utils.util import opt_get

-def parse_libri(line, base_path, split="|"):
-    fpt = line.strip().split(split)
-    fpt[0] = os.path.join(base_path, fpt[0])
-    return fpt
-
-def parse_tsv(line, base_path):
-    fpt = line.strip().split('\t')
-    return os.path.join(base_path, f'{fpt[1]}'), fpt[0]
-
 def parse_tsv_aligned_codes(line, base_path):
     fpt = line.strip().split('\t')
     def convert_string_list_to_tensor(strlist):
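For context, parse_tsv_aligned_codes (its body continues in the next hunk) expects three tab-separated columns: transcript, wav path relative to the TSV's directory, and a serialized list of aligned codes. A self-contained sketch with a made-up row; the bracketed-list serialization is an assumption inferred from the name convert_string_list_to_tensor:

    import os
    import torch

    def parse_tsv_aligned_codes_sketch(line, base_path):
        # Mirrors the real parser: column 0 is the text, column 1 the wav path,
        # column 2 the aligned codes serialized as a string.
        text, rel_path, codes = line.strip().split('\t')
        code_tensor = torch.tensor([int(t) for t in codes.strip('[]').split(',')])
        return os.path.join(base_path, rel_path), text, code_tensor

    row = 'hello world\tspeaker0/clip_000123.wav\t[0, 0, 17, 17, 42]'
    wav_path, text, codes = parse_tsv_aligned_codes_sketch(row, '/data/books1')
    # wav_path == '/data/books1/speaker0/clip_000123.wav'; codes.shape == torch.Size([5])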
@@ -38,27 +28,19 @@ def parse_tsv_aligned_codes(line, base_path):
     return os.path.join(base_path, f'{fpt[1]}'), fpt[0], convert_string_list_to_tensor(fpt[2])

-def parse_mozilla_cv(line, base_path):
-    components = line.strip().split('\t')
-    return os.path.join(base_path, f'clips/{components[1]}'), components[2]
-
-def parse_voxpopuli(line, base_path):
-    line = line.strip().split('\t')
-    file, raw_text, norm_text, speaker_id, split, gender = line
-    year = file[:4]
-    return os.path.join(base_path, year, f'{file}.ogg.wav'), raw_text
-
-class CharacterTokenizer:
-    def encode(self, txt):
-        return text_to_sequence(txt, ['english_cleaners'])
-
-    def decode(self, seq):
-        return sequence_to_text(seq)
-
-class TextWavLoader(torch.utils.data.Dataset):
+class FastPairedVoiceDataset(torch.utils.data.Dataset):
+    """
+    This dataset is derived from paired_voice_audio, but it only supports loading from TSV files generated by the
+    ocotillo transcription engine, which include alignment codes. To support the vastly larger TSV files, this dataset
+    uses an indexing mechanism which randomly selects offsets within the transcription file to seek to. The data
+    returned is relative to these offsets.
+
+    In practice, this means two things:
+    1) Index {i} of this dataset means nothing: fetching from the same index will almost always return different data.
+    2) This dataset has a slight bias for items with longer text or longer filenames.
+
+    The upshot is that this dataset loads extremely quickly and consumes almost no system memory.
+    """
     def __init__(self, hparams):
         self.paths = hparams['path']
         if not isinstance(self.paths, list):
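The indexing mechanism described in the docstring reduces to: pick a uniformly random byte offset, seek there, discard the partial line you landed in, and return the next complete line. This is also where the bias in point 2 comes from: a row occupying more bytes is more likely to contain the sampled offset. A simplified single-file sketch (the real load_random_line additionally spreads the offset across multiple TSVs and retries on decode or parse failures):

    import os
    import random

    def sample_random_line(path):
        offset = random.randint(0, os.path.getsize(path) - 1)
        with open(path, 'r', encoding='utf-8') as f:
            f.seek(offset)
            f.readline()         # discard the (probably partial) line we landed in
            line = f.readline()  # the next complete line
        return line              # may be '' near EOF, in which case the caller retries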
@@ -66,16 +48,10 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
         self.total_size_bytes = sum(self.paths_size_bytes)

-        self.fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
-        if not isinstance(self.fetcher_mode, list):
-            self.fetcher_mode = [self.fetcher_mode]
-        assert len(self.paths) == len(self.fetcher_mode)
-
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
         self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
         self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
-        self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
         self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
@@ -84,7 +60,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
         self.max_text_len = opt_get(hparams, ['max_text_length'], None)
         assert self.max_wav_len is not None and self.max_text_len is not None
-        self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True)
+        self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False)
         if self.use_bpe_tokenizer:
             from data.audio.voice_tokenizer import VoiceBpeTokenizer
             self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
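With the default now False, character-level tokenization becomes the default path. The else-branch is not visible in this hunk, but given the CharacterTokenizer import added at the top of the file, it presumably reads:

        else:
            self.tokenizer = CharacterTokenizer()  # assumed else-branch, not shown in this diff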
@@ -119,7 +95,6 @@ class TextWavLoader(torch.utils.data.Dataset):
             else:
                 rand_offset -= self.paths_size_bytes[i]
         path = self.paths[i]
-        fm = self.fetcher_mode[i]
         with open(path, 'r', encoding='utf-8') as f:
             f.seek(rand_offset)
             # Read the rest of the line we seeked to, then the line after that.
@@ -132,18 +107,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         if l2:
             try:
                 base_path = os.path.dirname(path)
-                if fm == 'lj' or fm == 'libritts':
-                    return parse_libri(l2, base_path)
-                elif fm == 'tsv':
-                    return parse_tsv_aligned_codes(l2, base_path) if self.load_aligned_codes else parse_tsv(l2, base_path)
-                elif fm == 'mozilla_cv':
-                    assert not self.load_conditioning  # Conditioning inputs are incompatible with mozilla_cv
-                    return parse_mozilla_cv(l2, base_path)
-                elif fm == 'voxpopuli':
-                    assert not self.load_conditioning  # Conditioning inputs are incompatible with voxpopuli
-                    return parse_voxpopuli(l2, base_path)
-                else:
-                    raise NotImplementedError()
+                return parse_tsv_aligned_codes(l2, base_path)
             except:
                 print(f"error parsing random offset: {sys.exc_info()}")
         return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
@@ -164,8 +128,6 @@ class TextWavLoader(torch.utils.data.Dataset):
             if self.debug_failures:
                 print(f"error loading {apt[0]} {sys.exc_info()}")
             return self[(index+1) % len(self)]
-        if self.load_aligned_codes:
-            aligned_codes = apt[2]
+        aligned_codes = apt[2]

         actually_skipped_items = self.skipped_items
@@ -183,7 +145,6 @@ class TextWavLoader(torch.utils.data.Dataset):
         orig_text_len = tseq.shape[0]
         if wav.shape[-1] != self.max_wav_len:
             wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
-        if self.load_aligned_codes:
-            # These codes are aligned to audio inputs, so make sure to pad them as well.
-            aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
+        # These codes are aligned to audio inputs, so make sure to pad them as well.
+        aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
         if tseq.shape[0] != self.max_text_len:
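Since the codes are always loaded now, every item is padded to a fixed code length derived from the wav length. With the defaults from the __main__ block below, the arithmetic works out as follows:

    max_wav_len = 255995                      # samples at 22050 Hz, about 11.6 s of audio
    ratio = 443                               # audio samples per aligned code (~49.8 codes/s)
    max_aligned_codes = max_wav_len // ratio  # 577 code slots after padding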
@@ -191,6 +152,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         res = {
             'real_text': text,
             'padded_text': tseq,
+            'aligned_codes': aligned_codes,
             'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
             'wav': wav,
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
@@ -200,63 +162,28 @@ class TextWavLoader(torch.utils.data.Dataset):
         if self.load_conditioning:
             res['conditioning'] = cond
             res['conditioning_contains_self'] = cond_is_self
-        if self.load_aligned_codes:
-            res['aligned_codes'] = aligned_codes
         return res

     def __len__(self):
-        return self.total_size_bytes // 1000  # 1000 cuts down a TSV file to the actual length pretty well, but doesn't work with the other formats.
+        return self.total_size_bytes // 1000  # 1000 cuts down a TSV file to the actual length pretty well.

-class PairedVoiceDebugger:
-    def __init__(self):
-        self.total_items = 0
-        self.loaded_items = 0
-        self.self_conditioning_items = 0
-
-    def get_state(self):
-        return {'total_items': self.total_items,
-                'loaded_items': self.loaded_items,
-                'self_conditioning_items': self.self_conditioning_items}
-
-    def load_state(self, state):
-        if isinstance(state, dict):
-            self.total_items = opt_get(state, ['total_items'], 0)
-            self.loaded_items = opt_get(state, ['loaded_items'], 0)
-            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
-
-    def update(self, batch):
-        self.total_items += batch['wav'].shape[0]
-        self.loaded_items += batch['skipped_items'].sum().item()
-        if 'conditioning' in batch.keys():
-            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
-
-    def get_debugging_map(self):
-        return {
-            'total_samples_loaded': self.total_items,
-            'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
-            'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
-        }
-
 if __name__ == '__main__':
     batch_sz = 16
     params = {
-        'mode': 'paired_voice_audio',
-        #'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
-        'path': ['Y:\\bigasr_dataset\\mozcv\\en\\train.tsv'],
-        'fetcher_mode': ['mozilla_cv'],
+        'mode': 'fast_paired_voice_audio',
+        'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
         'phase': 'train',
         'n_workers': 0,
         'batch_size': batch_sz,
         'max_wav_length': 255995,
         'max_text_length': 200,
         'sample_rate': 22050,
-        'load_conditioning': False,
-        'num_conditioning_candidates': 2,
+        'load_conditioning': True,
+        'num_conditioning_candidates': 1,
         'conditioning_length': 44000,
-        'use_bpe_tokenizer': True,
-        'load_aligned_codes': False,
+        'use_bpe_tokenizer': False,
+        'load_aligned_codes': True,
     }

     from data import create_dataset, create_dataloader
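The remainder of the __main__ block falls outside this diff. A plausible continuation, assuming create_dataloader accepts a collate_fn argument like the other audio dataset smoke tests in this repo, would be:

    ds, collate = create_dataset(params, return_collate=True)
    dl = create_dataloader(ds, params, collate_fn=collate)
    for i, batch in enumerate(dl):
        # Each batch now always carries 'aligned_codes' alongside the usual keys.
        print(batch['real_text'][0], batch['wav'].shape, batch['aligned_codes'].shape)
        if i == 3:
            break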