support tts typing

James Betker 2022-04-16 23:36:57 -06:00
parent 48cb6a5abd
commit 8fe0dff33c
4 changed files with 84 additions and 32 deletions

View File

@@ -48,6 +48,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.paths = [self.paths]
         self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
         self.total_size_bytes = sum(self.paths_size_bytes)
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
@@ -101,22 +102,23 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             else:
                 rand_offset -= self.paths_size_bytes[i]
         path = self.paths[i]
+        type = self.types[i]
         with open(path, 'r', encoding='utf-8') as f:
             f.seek(rand_offset)
             # Read the rest of the line we seeked to, then the line after that.
             try: # This can fail when seeking to a UTF-8 escape byte.
                 f.readline()
             except:
-                return self.load_random_line(depth=depth + 1) # On failure, just recurse and try again.
+                return self.load_random_line(depth=depth + 1), type # On failure, just recurse and try again.
             l2 = f.readline()
             if l2:
                 try:
                     base_path = os.path.dirname(path)
-                    return parse_tsv_aligned_codes(l2, base_path)
+                    return parse_tsv_aligned_codes(l2, base_path), type
                 except:
                     print(f"error parsing random offset: {sys.exc_info()}")
-        return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
+        return self.load_random_line(depth=depth+1), type # On failure, just recurse and try again.

     def get_ctc_metadata(self, codes):
         grouped = groupby(codes.tolist())
@@ -155,7 +157,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         start = time.time()
         self.skipped_items += 1
-        apt = self.load_random_line()
+        apt, type = self.load_random_line()
         try:
             tseq, wav, text, path = self.get_wav_text_pair(apt)
             if text is None or len(text.strip()) == 0:
@@ -204,7 +206,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
-            'load_time': self.load_times.mean()
+            'load_time': self.load_times.mean(),
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -266,10 +269,11 @@ if __name__ == '__main__':
         'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
                  'y:/libritts/train-clean-100/transcribed-oco.tsv',
                  'y:/libritts/train-clean-360/transcribed-oco.tsv',
-                 'y:/clips/books1/transcribed-w2v.tsv',
-                 'y:/clips/books2/transcribed-w2v.tsv',
-                 'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv',
+                 'y:/clips/books1/transcribed-oco.tsv',
+                 'y:/clips/books2/transcribed-oco.tsv',
+                 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
                  'y:/clips/podcasts-1/transcribed-oco.tsv',],
+        'types': [0,1,1,1,2,2,0],
        'phase': 'train',
        'n_workers': 0,
        'batch_size': batch_sz,
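
The dataset-side contract introduced here is small: every entry in hparams['path'] is paired with an integer in the new hparams['types'] list (all zeros when the key is omitted), and that integer travels with each drawn sample, surfacing in the returned batch as res['type']. A minimal sketch of the pairing, with shortened, purely illustrative paths:

    hparams = {
        'path':  ['libritts/transcribed-oco.tsv',      # illustrative path
                  'podcasts-1/transcribed-oco.tsv'],   # illustrative path
        'types': [0, 2],                               # one integer label per path
    }
    # Mirrors the opt_get default above: omitting 'types' labels every path as 0.
    types = hparams.get('types', [0 for _ in hparams['path']])
    for p, t in zip(hparams['path'], types):
        print(f'samples drawn from {p} carry type {t}')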

View File

@@ -9,11 +9,28 @@ import torchaudio
 from tqdm import tqdm

 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
-from models.audio.tts.tacotron2 import load_filepaths_and_text
+from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
 from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
 from utils.util import opt_get


+def load_tsv_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 2:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
+    return filepaths_and_text
+
+
 def load_tsv(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
@@ -40,6 +57,23 @@ def convert_string_list_to_tensor(strlist):
     return torch.tensor(as_ints)


+def load_tsv_aligned_codes_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 3:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
+    return filepaths_and_text
+
+
 def load_tsv_aligned_codes(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
@@ -57,15 +91,15 @@ def load_tsv_aligned_codes(filename):
     return filepaths_and_text


-def load_mozilla_cv(filename):
+def load_mozilla_cv(filename, type):
     with open(filename, encoding='utf-8') as f:
         components = [line.strip().split('\t') for line in f][1:] # First line is the header
         base = os.path.dirname(filename)
-        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
+        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
     return filepaths_and_text


-def load_voxpopuli(filename):
+def load_voxpopuli(filename, type):
     with open(filename, encoding='utf-8') as f:
         lines = [line.strip().split('\t') for line in f][1:] # First line is the header
         base = os.path.dirname(filename)
@@ -75,7 +109,7 @@ def load_voxpopuli(filename):
                 continue
             file, raw_text, norm_text, speaker_id, split, gender = line
             year = file[:4]
-            filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
+            filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
     return filepaths_and_text
@@ -92,6 +126,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.path = hparams['path']
         if not isinstance(self.path, list):
             self.path = [self.path]
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.path])

         fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
         if not isinstance(fetcher_mode, list):
@@ -105,11 +140,11 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
         self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
         self.audiopaths_and_text = []
-        for p, fm in zip(self.path, fetcher_mode):
+        for p, fm, type in zip(self.path, fetcher_mode, self.types):
             if fm == 'lj' or fm == 'libritts':
-                fetcher_fn = load_filepaths_and_text
+                fetcher_fn = load_filepaths_and_text_type
             elif fm == 'tsv':
-                fetcher_fn = load_tsv_aligned_codes if self.load_aligned_codes else load_tsv
+                fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
             elif fm == 'mozilla_cv':
                 assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
                 fetcher_fn = load_mozilla_cv
@@ -118,7 +153,7 @@ class TextWavLoader(torch.utils.data.Dataset):
                 fetcher_fn = load_voxpopuli
             else:
                 raise NotImplementedError()
-            self.audiopaths_and_text.extend(fetcher_fn(p))
+            self.audiopaths_and_text.extend(fetcher_fn(p, type))
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
         random.seed(hparams.seed)
@@ -138,10 +173,10 @@ class TextWavLoader(torch.utils.data.Dataset):
     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
-        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+        audiopath, text, type = audiopath_and_text[0], audiopath_and_text[1], audiopath_and_text[2]
         text_seq = self.get_text(text)
         wav = load_audio(audiopath, self.sample_rate)
-        return (text_seq, wav, text, audiopath_and_text[0])
+        return (text_seq, wav, text, audiopath_and_text[0], type)

     def get_text(self, text):
         tokens = self.tokenizer.encode(text)
@@ -156,7 +191,7 @@ class TextWavLoader(torch.utils.data.Dataset):
     def __getitem__(self, index):
         self.skipped_items += 1
         try:
-            tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
+            tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
             if text is None or len(text.strip()) == 0:
                 raise ValueError
             if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
@@ -202,6 +237,7 @@ class TextWavLoader(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -249,7 +285,7 @@ if __name__ == '__main__':
     batch_sz = 8
     params = {
         'mode': 'paired_voice_audio',
-        'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
+        'path': ['Y:\\clips\\books1\\transcribed-oco.tsv'],
        'fetcher_mode': ['tsv'],
        'phase': 'train',
        'n_workers': 0,

View File

@@ -26,6 +26,14 @@ def load_wav_to_torch(full_path):
     return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)


+def load_filepaths_and_text_type(filename, type, split="|"):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
+        base = os.path.dirname(filename)
+        for j in range(len(filepaths_and_text)):
+            filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
+    return filepaths_and_text
+
+
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
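
The new helper mirrors load_filepaths_and_text but appends the dataset's type as the last element of every row, after turning the first field into an absolute path. A hedged usage sketch (the metadata filename is hypothetical; the import matches the one added to the dataset file above):

    from models.audio.tts.tacotron2 import load_filepaths_and_text_type

    # Each pipe-separated line becomes [abs_wav_path, <remaining fields...>, type].
    rows = load_filepaths_and_text_type('my_dataset/metadata.csv', type=1)  # hypothetical file
    wav_path, transcript, sample_type = rows[0][0], rows[0][1], rows[0][-1]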

View File

@@ -238,10 +238,10 @@ class MelEncoder(nn.Module):

 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
-                 mel_length_compression=1024, number_text_tokens=256,
-                 start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
+                 mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
+                 types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -252,7 +252,6 @@ class UnifiedVoice(nn.Module):
             max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
             mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
             number_text_tokens:
-            start_text_token:
             stop_text_token:
             number_mel_codes:
             start_mel_token:
@@ -265,8 +264,8 @@ class UnifiedVoice(nn.Module):
         super().__init__()

         self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
+        self.start_text_token = number_text_tokens * types
+        self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
@@ -279,7 +278,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
@@ -294,7 +293,7 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0

         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)

         # Initialize the embeddings per the GPT-2 scheme
@@ -366,7 +365,7 @@ class UnifiedVoice(nn.Module):
         else:
             return first_logits

-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
                 return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@@ -383,6 +382,10 @@ class UnifiedVoice(nn.Module):
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
+        # Types are expressed by expanding the text embedding space.
+        if types is not None:
+            text_inputs = text_inputs * (1+types).unsqueeze(-1)
+
         if clip_inputs:
             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
             # chopping the inputs by the maximum actual length.
@@ -579,10 +582,11 @@ def register_unified_voice2(opt_net, opt):

 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
-            torch.tensor([250*256,195*256]))
+            torch.tensor([250*256,195*256]),
+            types=torch.tensor([0, 1]))
     #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
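
For reference, a minimal sketch of the token-space bookkeeping the model changes imply, using the defaults shown above (this is a reading of the diff, not code from the commit): with number_text_tokens=256 and types=2, the text embedding and text head grow to 256*2+1 = 513 entries, start_text_token becomes 512, stop_text_token is pinned to 0, and forward() rescales incoming token ids by (1 + type) so they stay inside the enlarged table.

    import torch

    number_text_tokens, types = 256, 2
    vocab_rows = number_text_tokens * types + 1      # 513 rows in text_embedding / text_head
    start_text_token = number_text_tokens * types    # 512
    stop_text_token = 0

    text_inputs = torch.randint(high=number_text_tokens, size=(2, 5))
    type_ids = torch.tensor([0, 1])                  # one type label per batch element
    # forward() multiplies ids by (1 + type): type-0 rows keep their ids,
    # type-1 rows map id -> 2*id, which is still below the 513-row table.
    scaled = text_inputs * (1 + type_ids).unsqueeze(-1)
    assert scaled.max().item() < vocab_rows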