support tts typing

This commit is contained in:
James Betker 2022-04-16 23:36:57 -06:00
parent 48cb6a5abd
commit 8fe0dff33c
4 changed files with 84 additions and 32 deletions

View File

@ -48,6 +48,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.paths = [self.paths]
self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
self.total_size_bytes = sum(self.paths_size_bytes)
self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
@ -101,22 +102,23 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
else:
rand_offset -= self.paths_size_bytes[i]
path = self.paths[i]
type = self.types[i]
with open(path, 'r', encoding='utf-8') as f:
f.seek(rand_offset)
# Read the rest of the line we seeked to, then the line after that.
try: # This can fail when seeking to a UTF-8 escape byte.
f.readline()
except:
return self.load_random_line(depth=depth + 1) # On failure, just recurse and try again.
return self.load_random_line(depth=depth + 1), type # On failure, just recurse and try again.
l2 = f.readline()
if l2:
try:
base_path = os.path.dirname(path)
return parse_tsv_aligned_codes(l2, base_path)
return parse_tsv_aligned_codes(l2, base_path), type
except:
print(f"error parsing random offset: {sys.exc_info()}")
return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
return self.load_random_line(depth=depth+1), type # On failure, just recurse and try again.
def get_ctc_metadata(self, codes):
grouped = groupby(codes.tolist())
@ -155,7 +157,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
start = time.time()
self.skipped_items += 1
apt = self.load_random_line()
apt, type = self.load_random_line()
try:
tseq, wav, text, path = self.get_wav_text_pair(apt)
if text is None or len(text.strip()) == 0:
@ -204,7 +206,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
'filenames': path,
'skipped_items': actually_skipped_items,
'load_time': self.load_times.mean()
'load_time': self.load_times.mean(),
'type': type,
}
if self.load_conditioning:
res['conditioning'] = cond
@ -266,10 +269,11 @@ if __name__ == '__main__':
'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
'y:/libritts/train-clean-100/transcribed-oco.tsv',
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-w2v.tsv',
'y:/clips/books2/transcribed-w2v.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv',
'y:/clips/books1/transcribed-oco.tsv',
'y:/clips/books2/transcribed-oco.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
'y:/clips/podcasts-1/transcribed-oco.tsv',],
'types': [0,1,1,1,2,2,0],
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,

View File

@ -9,11 +9,28 @@ import torchaudio
from tqdm import tqdm
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
from utils.util import opt_get
def load_tsv_type(filename, type):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = []
base = os.path.dirname(filename)
bad_lines = 0
for line in f:
components = line.strip().split('\t')
if len(components) < 2:
bad_lines += 1
if bad_lines > 1000:
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
raise ValueError
continue
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
return filepaths_and_text
def load_tsv(filename):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = []
@ -40,6 +57,23 @@ def convert_string_list_to_tensor(strlist):
return torch.tensor(as_ints)
def load_tsv_aligned_codes_type(filename, type):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = []
base = os.path.dirname(filename)
bad_lines = 0
for line in f:
components = line.strip().split('\t')
if len(components) < 3:
bad_lines += 1
if bad_lines > 1000:
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
raise ValueError
continue
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
return filepaths_and_text
def load_tsv_aligned_codes(filename):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = []
@ -57,15 +91,15 @@ def load_tsv_aligned_codes(filename):
return filepaths_and_text
def load_mozilla_cv(filename):
def load_mozilla_cv(filename, type):
with open(filename, encoding='utf-8') as f:
components = [line.strip().split('\t') for line in f][1:] # First line is the header
base = os.path.dirname(filename)
filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
return filepaths_and_text
def load_voxpopuli(filename):
def load_voxpopuli(filename, type):
with open(filename, encoding='utf-8') as f:
lines = [line.strip().split('\t') for line in f][1:] # First line is the header
base = os.path.dirname(filename)
@ -75,7 +109,7 @@ def load_voxpopuli(filename):
continue
file, raw_text, norm_text, speaker_id, split, gender = line
year = file[:4]
filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
return filepaths_and_text
@ -92,6 +126,7 @@ class TextWavLoader(torch.utils.data.Dataset):
self.path = hparams['path']
if not isinstance(self.path, list):
self.path = [self.path]
self.types = opt_get(hparams, ['types'], [0 for _ in self.path])
fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
if not isinstance(fetcher_mode, list):
@ -105,11 +140,11 @@ class TextWavLoader(torch.utils.data.Dataset):
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
self.audiopaths_and_text = []
for p, fm in zip(self.path, fetcher_mode):
for p, fm, type in zip(self.path, fetcher_mode, self.types):
if fm == 'lj' or fm == 'libritts':
fetcher_fn = load_filepaths_and_text
fetcher_fn = load_filepaths_and_text_type
elif fm == 'tsv':
fetcher_fn = load_tsv_aligned_codes if self.load_aligned_codes else load_tsv
fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
elif fm == 'mozilla_cv':
assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
fetcher_fn = load_mozilla_cv
@ -118,7 +153,7 @@ class TextWavLoader(torch.utils.data.Dataset):
fetcher_fn = load_voxpopuli
else:
raise NotImplementedError()
self.audiopaths_and_text.extend(fetcher_fn(p))
self.audiopaths_and_text.extend(fetcher_fn(p, type))
self.text_cleaners = hparams.text_cleaners
self.sample_rate = hparams.sample_rate
random.seed(hparams.seed)
@ -138,10 +173,10 @@ class TextWavLoader(torch.utils.data.Dataset):
def get_wav_text_pair(self, audiopath_and_text):
# separate filename and text
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
audiopath, text, type = audiopath_and_text[0], audiopath_and_text[1], audiopath_and_text[2]
text_seq = self.get_text(text)
wav = load_audio(audiopath, self.sample_rate)
return (text_seq, wav, text, audiopath_and_text[0])
return (text_seq, wav, text, audiopath_and_text[0], type)
def get_text(self, text):
tokens = self.tokenizer.encode(text)
@ -156,7 +191,7 @@ class TextWavLoader(torch.utils.data.Dataset):
def __getitem__(self, index):
self.skipped_items += 1
try:
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
if text is None or len(text.strip()) == 0:
raise ValueError
if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
@ -202,6 +237,7 @@ class TextWavLoader(torch.utils.data.Dataset):
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
'filenames': path,
'skipped_items': actually_skipped_items,
'type': type,
}
if self.load_conditioning:
res['conditioning'] = cond
@ -249,7 +285,7 @@ if __name__ == '__main__':
batch_sz = 8
params = {
'mode': 'paired_voice_audio',
'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
'path': ['Y:\\clips\\books1\\transcribed-oco.tsv'],
'fetcher_mode': ['tsv'],
'phase': 'train',
'n_workers': 0,

View File

@ -26,6 +26,14 @@ def load_wav_to_torch(full_path):
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def load_filepaths_and_text_type(filename, type, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
base = os.path.dirname(filename)
for j in range(len(filepaths_and_text)):
filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
return filepaths_and_text
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [line.strip().split(split) for line in f]

View File

@ -238,10 +238,10 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
types=1):
"""
Args:
layers: Number of layers in transformer stack.
@ -252,7 +252,6 @@ class UnifiedVoice(nn.Module):
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
@ -265,8 +264,8 @@ class UnifiedVoice(nn.Module):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.start_text_token = number_text_tokens * types
self.stop_text_token = 0
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
@ -279,7 +278,7 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
else:
@ -294,7 +293,7 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
@ -366,7 +365,7 @@ class UnifiedVoice(nn.Module):
else:
return first_logits
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False, clip_inputs=True):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@ -383,6 +382,10 @@ class UnifiedVoice(nn.Module):
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
"""
# Types are expressed by expanding the text embedding space.
if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1)
if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
@ -579,10 +582,11 @@ def register_unified_voice2(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=256, size=(2,120)),
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([250*256,195*256]))
torch.tensor([250*256,195*256]),
types=torch.tensor([0, 1]))
#gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))