DL-Art-School/codes/data/audio/fast_paired_dataset.py


import hashlib
import os
import random
import sys
import time
from itertools import groupby
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from tqdm import tqdm
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from utils.util import opt_get
def parse_tsv_aligned_codes(line, base_path):
    """Parses a TSV row of the form `<text>\\t<relative audio path>\\t<aligned code list>` into
    (absolute audio path, text, aligned codes tensor)."""
    fpt = line.strip().split('\t')
    def convert_string_list_to_tensor(strlist):
        # Strip the surrounding brackets of a stringified list, e.g. '[1, 2, 3]' -> '1, 2, 3'.
        if strlist.startswith('['):
            strlist = strlist[1:]
        if strlist.endswith(']'):
            strlist = strlist[:-1]
        as_ints = [int(s) for s in strlist.split(', ')]
        return torch.tensor(as_ints)
    return os.path.join(base_path, f'{fpt[1]}'), fpt[0], convert_string_list_to_tensor(fpt[2])
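
# A minimal sketch of the TSV row this parser expects (values are hypothetical):
#   hello world\tspeaker1/clip0001.wav\t[0, 0, 14, 14, 0, 7]
# With base_path='/data/libritts', parse_tsv_aligned_codes would return:
#   ('/data/libritts/speaker1/clip0001.wav', 'hello world', tensor([0, 0, 14, 14, 0, 7]))
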
class FastPairedVoiceDataset(torch.utils.data.Dataset):
    """
    This dataset is derived from paired_voice_audio, but it only supports loading from TSV files generated by the
    ocotillo transcription engine, which include alignment codes. To support the vastly larger TSV files, this dataset
    uses an indexing mechanism which randomly selects byte offsets within the transcription file to seek to. The data
    returned is relative to these offsets.

    In practice, this means two things:
    1) Index {i} of this dataset means nothing: fetching from the same index will almost always return different data.
       As a result, this dataset should not be used for validation or test runs. Use the PairedVoiceAudio dataset instead.
    2) This dataset has a slight bias towards items with longer text or longer filenames.

    The upshot is that this dataset loads extremely quickly and consumes almost no system memory.
    """
    def __init__(self, hparams):
        self.paths = hparams['path']
        if not isinstance(self.paths, list):
            self.paths = [self.paths]
        self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
        self.total_size_bytes = sum(self.paths_size_bytes)
        self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])

        self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
        self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
        self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
        self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
        self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
        self.text_cleaners = hparams.text_cleaners
        self.sample_rate = hparams.sample_rate
        # One aligned code covers 443 audio samples at a 22050Hz sampling rate; scale that ratio to the configured rate.
        self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
        self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
        self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
        if self.max_wav_len is not None:
            self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
        self.max_text_len = opt_get(hparams, ['max_text_length'], None)
        assert self.max_wav_len is not None and self.max_text_len is not None
        self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False)
        if self.use_bpe_tokenizer:
            from data.audio.voice_tokenizer import VoiceBpeTokenizer
            self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
        else:
            self.tokenizer = CharacterTokenizer()
        self.skipped_items = 0  # Records how many items are skipped when accessing an index.
        self.load_times = torch.zeros((256,))
        self.load_ind = 0
    def get_wav_text_pair(self, audiopath_and_text):
        # Separate filename and text.
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text_seq = self.get_text(text)
        wav = load_audio(audiopath, self.sample_rate)
        return (text_seq, wav, text, audiopath_and_text[0])
    def get_text(self, text):
        tokens = self.tokenizer.encode(text)
        tokens = torch.IntTensor(tokens)
        if self.use_bpe_tokenizer:
            # Assert that no UNK or start tokens (id 1) were produced.
            assert not torch.any(tokens == 1)
            # The stop token (id 0) should always be sacred.
            assert not torch.any(tokens == 0)
        return tokens
    def load_random_line(self, depth=0):
        assert depth < 10
        rand_offset = random.randint(0, self.total_size_bytes)
        # Map the global offset onto a specific file by walking the per-file byte sizes.
        for i in range(len(self.paths)):
            if rand_offset < self.paths_size_bytes[i]:
                break
            else:
                rand_offset -= self.paths_size_bytes[i]
        path = self.paths[i]
        type = self.types[i]
        with open(path, 'r', encoding='utf-8') as f:
            f.seek(rand_offset)
            # Read the rest of the line we seeked into, then the line after that.
            try:  # This can fail when seeking to a UTF-8 escape byte.
                f.readline()
            except:
                return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again; the recursive call already returns a (line, type) tuple.
            l2 = f.readline()
            if l2:
                try:
                    base_path = os.path.dirname(path)
                    return parse_tsv_aligned_codes(l2, base_path), type
                except:
                    print(f"error parsing random offset: {sys.exc_info()}")
        return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
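
    # Note on the sampling above: seeking to a uniformly random byte offset lands inside line k with
    # probability proportional to k's byte length, and the line actually returned is the one after it.
    # Since neighbouring rows in these TSVs tend to be of similar length, this is the source of the
    # mild bias towards longer text/filenames described in the class docstring.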
    def get_ctc_metadata(self, codes):
        grouped = groupby(codes.tolist())
        rcodes, repeats, seps = [], [], [0]
        for val, group in grouped:
            if val == 0:
                seps[-1] = len(list(group))  # This is a very important distinction! It means the padding is attributed to the character that follows it.
            else:
                rcodes.append(val)
                repeats.append(len(list(group)))
                seps.append(0)
        rcodes = torch.tensor(rcodes)
        # These clip values are sane maximum values which I did not see in the datasets I have access to.
        repeats = torch.clip(torch.tensor(repeats), min=1, max=30)
        seps = torch.clip(torch.tensor(seps[:-1]), max=120)
        # Pad or clip the codes to get them to exactly self.max_text_len.
        orig_lens = rcodes.shape[0]
        if rcodes.shape[0] < self.max_text_len:
            gap = self.max_text_len - rcodes.shape[0]
            rcodes = F.pad(rcodes, (0, gap))
            repeats = F.pad(repeats, (0, gap), value=1)  # The minimum value for repeats is 1, hence this is the pad value too.
            seps = F.pad(seps, (0, gap))
        elif rcodes.shape[0] > self.max_text_len:
            rcodes = rcodes[:self.max_text_len]
            repeats = repeats[:self.max_text_len]
            seps = seps[:self.max_text_len]
        return {
            'ctc_raw_codes': rcodes,
            'ctc_separators': seps,
            'ctc_repeats': repeats,
            'ctc_raw_lengths': orig_lens,
        }
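
    # Worked example for get_ctc_metadata, assuming 0 is the CTC blank/pad code:
    #   codes = [0, 0, 5, 5, 3, 0, 3]
    #   -> ctc_raw_codes  = [5, 3, 3]   (blanks removed, repeated codes collapsed)
    #   -> ctc_repeats    = [2, 1, 1]   (length of each collapsed run)
    #   -> ctc_separators = [2, 0, 1]   (blanks immediately preceding each code)
    # (shown before padding/clipping to max_text_len.)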
    def __getitem__(self, index):
        start = time.time()
        self.skipped_items += 1
        apt, type = self.load_random_line()
        try:
            tseq, wav, text, path = self.get_wav_text_pair(apt)
            if text is None or len(text.strip()) == 0:
                raise ValueError
            cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
                                                    n=self.conditioning_candidates) if self.load_conditioning else (None, False)
        except:
            if self.skipped_items > 100:
                raise  # Rethrow if we have nested too far.
            if self.debug_failures:
                print(f"error loading {apt[0]} {sys.exc_info()}")
            return self[(index + 1) % len(self)]

        raw_codes = apt[2]
        aligned_codes = raw_codes

        actually_skipped_items = self.skipped_items
        self.skipped_items = 0
        if wav is None or \
                (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
                (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
            # It's hard to handle this situation properly. The best bet is to return a random valid item, which skews the dataset somewhat as a result.
            if self.debug_failures:
                print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
            rv = random.randint(0, len(self) - 1)
            return self[rv]
        orig_output = wav.shape[-1]
        orig_text_len = tseq.shape[0]
        orig_aligned_code_length = aligned_codes.shape[0]
        if wav.shape[-1] != self.max_wav_len:
            wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
            # These codes are aligned to audio inputs, so make sure to pad them as well.
            aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes - aligned_codes.shape[0]))
        if tseq.shape[0] != self.max_text_len:
            tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))

        elapsed = time.time() - start
        self.load_times[self.load_ind] = elapsed
        self.load_ind = (self.load_ind + 1) % len(self.load_times)

        res = {
            'real_text': text,
            'padded_text': tseq,
            'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
            'wav': wav,
            'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
            'filenames': path,
            'skipped_items': actually_skipped_items,
            'load_time': self.load_times.mean(),
            'type': type,
        }
        if self.load_conditioning:
            res['conditioning'] = cond
            res['conditioning_contains_self'] = cond_is_self
        if self.load_aligned_codes:
            res['aligned_codes'] = aligned_codes
            res['aligned_codes_lengths'] = orig_aligned_code_length
        if self.produce_ctc_metadata:
            res.update(self.get_ctc_metadata(raw_codes))
        return res
    def __len__(self):
        return self.total_size_bytes // 1000  # Dividing the file size by 1000 approximates the actual number of TSV lines pretty well.
class FastPairedVoiceDebugger:
    def __init__(self):
        self.total_items = 0
        self.loaded_items = 0
        self.self_conditioning_items = 0
        self.unique_files = set()
        self.load_time = 0

    def get_state(self):
        return {'total_items': self.total_items,
                'loaded_items': self.loaded_items,
                'self_conditioning_items': self.self_conditioning_items}

    def load_state(self, state):
        if isinstance(state, dict):
            self.total_items = opt_get(state, ['total_items'], 0)
            self.loaded_items = opt_get(state, ['loaded_items'], 0)
            self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)

    def update(self, batch):
        self.total_items += batch['wav'].shape[0]
        self.loaded_items += batch['skipped_items'].sum().item()
        self.load_time = batch['load_time'].mean().item()
        for filename in batch['filenames']:
            # Store digests rather than hash objects so set deduplication works by value.
            self.unique_files.add(hashlib.sha256(filename.encode('utf-8')).hexdigest())
        if 'conditioning' in batch.keys():
            self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()

    def get_debugging_map(self):
        return {
            'total_samples_loaded': self.total_items,
            'percent_skipped_samples': (self.loaded_items - self.total_items) / self.loaded_items,
            'percent_conditioning_is_self': self.self_conditioning_items / self.loaded_items,
            'unique_files_loaded': len(self.unique_files),
            'load_time': self.load_time,
        }
if __name__ == '__main__':
    batch_sz = 16
    params = {
        'mode': 'fast_paired_voice_audio',
        'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
                 'y:/libritts/train-clean-100/transcribed-oco.tsv',
                 'y:/libritts/train-clean-360/transcribed-oco.tsv',
                 'y:/clips/books1/transcribed-oco.tsv',
                 'y:/clips/books2/transcribed-oco.tsv',
                 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
                 'y:/clips/podcasts-1/transcribed-oco.tsv',],
        'types': [0, 1, 1, 1, 2, 2, 0],
        'phase': 'train',
        'n_workers': 0,
        'batch_size': batch_sz,
        'max_wav_length': 220500,
        'max_text_length': 500,
        'sample_rate': 22050,
        'load_conditioning': True,
        'num_conditioning_candidates': 2,
        'conditioning_length': 102400,
        'use_bpe_tokenizer': True,
        'load_aligned_codes': False,
    }
    from data import create_dataset, create_dataloader

    def save(b, i, ib, key, c=None):
        if c is not None:
            torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
        else:
            torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

    ds, c = create_dataset(params, return_collate=True)
    dl = create_dataloader(ds, params, collate_fn=c)
    i = 0
    m = None
    max_pads, max_repeats = 0, 0
    for i, b in tqdm(enumerate(dl)):
        for ib in range(batch_sz):
            #max_pads = max(max_pads, b['ctc_pads'].max())
            #max_repeats = max(max_repeats, b['ctc_repeats'].max())
            print(f'{i} {ib} {b["real_text"][ib]}')
            save(b, i, ib, 'wav')
            save(b, i, ib, 'conditioning', 0)
            save(b, i, ib, 'conditioning', 1)
        if i > 15:
            break
    print(max_pads, max_repeats)
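
    # A hypothetical sketch of wiring FastPairedVoiceDebugger into the loop above; in training,
    # the framework would drive it through update()/get_state()/load_state() instead:
    #   debugger = FastPairedVoiceDebugger()
    #   for _, b in enumerate(dl):
    #       debugger.update(b)
    #   print(debugger.get_debugging_map())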