diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index e968bc45..9aaaa6fc 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -48,6 +48,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             self.paths = [self.paths]
         self.paths_size_bytes = [os.path.getsize(p) for p in self.paths]
         self.total_size_bytes = sum(self.paths_size_bytes)
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])
 
         self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
@@ -101,22 +102,23 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             else:
                 rand_offset -= self.paths_size_bytes[i]
         path = self.paths[i]
+        type = self.types[i]
         with open(path, 'r', encoding='utf-8') as f:
             f.seek(rand_offset)
             # Read the rest of the line we seeked to, then the line after that.
             try:  # This can fail when seeking to a UTF-8 escape byte.
                 f.readline()
             except:
-                return self.load_random_line(depth=depth + 1)  # On failure, just recurse and try again.
+                return self.load_random_line(depth=depth + 1), type  # On failure, just recurse and try again.
             l2 = f.readline()
 
         if l2:
             try:
                 base_path = os.path.dirname(path)
-                return parse_tsv_aligned_codes(l2, base_path)
+                return parse_tsv_aligned_codes(l2, base_path), type
             except:
                 print(f"error parsing random offset: {sys.exc_info()}")
-        return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
+        return self.load_random_line(depth=depth+1), type  # On failure, just recurse and try again.
 
     def get_ctc_metadata(self, codes):
         grouped = groupby(codes.tolist())
@@ -155,7 +157,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         start = time.time()
         self.skipped_items += 1
-        apt = self.load_random_line()
+        apt, type = self.load_random_line()
         try:
             tseq, wav, text, path = self.get_wav_text_pair(apt)
             if text is None or len(text.strip()) == 0:
@@ -204,7 +206,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
-            'load_time': self.load_times.mean()
+            'load_time': self.load_times.mean(),
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -266,10 +269,11 @@ if __name__ == '__main__':
         'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
                  'y:/libritts/train-clean-100/transcribed-oco.tsv',
                  'y:/libritts/train-clean-360/transcribed-oco.tsv',
-                 'y:/clips/books1/transcribed-w2v.tsv',
-                 'y:/clips/books2/transcribed-w2v.tsv',
-                 'y:/bigasr_dataset/hifi_tts/transcribed-w2v.tsv',
+                 'y:/clips/books1/transcribed-oco.tsv',
+                 'y:/clips/books2/transcribed-oco.tsv',
+                 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
                  'y:/clips/podcasts-1/transcribed-oco.tsv',],
+        'types': [0,1,1,1,2,2,0],
         'phase': 'train',
         'n_workers': 0,
         'batch_size': batch_sz,
diff --git a/codes/data/audio/paired_voice_audio_dataset.py b/codes/data/audio/paired_voice_audio_dataset.py
index e4dfd922..60b6a4a8 100644
--- a/codes/data/audio/paired_voice_audio_dataset.py
+++ b/codes/data/audio/paired_voice_audio_dataset.py
@@ -9,11 +9,28 @@ import torchaudio
 from tqdm import tqdm
 
 from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
-from models.audio.tts.tacotron2 import load_filepaths_and_text
+from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
 from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
 from utils.util import opt_get
 
 
+def load_tsv_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 2:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
+    return filepaths_and_text
+
+
 def load_tsv(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
@@ -40,6 +57,23 @@ def convert_string_list_to_tensor(strlist):
     return torch.tensor(as_ints)
 
 
+def load_tsv_aligned_codes_type(filename, type):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = []
+        base = os.path.dirname(filename)
+        bad_lines = 0
+        for line in f:
+            components = line.strip().split('\t')
+            if len(components) < 3:
+                bad_lines += 1
+                if bad_lines > 1000:
+                    print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
+                    raise ValueError
+                continue
+            filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
+    return filepaths_and_text
+
+
 def load_tsv_aligned_codes(filename):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = []
@@ -57,15 +91,15 @@ def load_tsv_aligned_codes(filename):
     return filepaths_and_text
 
 
-def load_mozilla_cv(filename):
+def load_mozilla_cv(filename, type):
     with open(filename, encoding='utf-8') as f:
         components = [line.strip().split('\t') for line in f][1:]  # First line is the header
         base = os.path.dirname(filename)
-        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
+        filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
     return filepaths_and_text
 
 
-def load_voxpopuli(filename):
+def load_voxpopuli(filename, type):
     with open(filename, encoding='utf-8') as f:
         lines = [line.strip().split('\t') for line in f][1:]  # First line is the header
         base = os.path.dirname(filename)
@@ -75,7 +109,7 @@ def load_voxpopuli(filename):
                 continue
             file, raw_text, norm_text, speaker_id, split, gender = line
             year = file[:4]
-            filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
+            filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
     return filepaths_and_text
 
 
@@ -92,6 +126,7 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.path = hparams['path']
         if not isinstance(self.path, list):
             self.path = [self.path]
+        self.types = opt_get(hparams, ['types'], [0 for _ in self.path])
 
         fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
         if not isinstance(fetcher_mode, list):
@@ -105,11 +140,11 @@ class TextWavLoader(torch.utils.data.Dataset):
         self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
         self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
         self.audiopaths_and_text = []
-        for p, fm in zip(self.path, fetcher_mode):
+        for p, fm, type in zip(self.path, fetcher_mode, self.types):
             if fm == 'lj' or fm == 'libritts':
-                fetcher_fn = load_filepaths_and_text
+                fetcher_fn = load_filepaths_and_text_type
             elif fm == 'tsv':
-                fetcher_fn = load_tsv_aligned_codes if self.load_aligned_codes else load_tsv
+                fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
             elif fm == 'mozilla_cv':
                 assert not self.load_conditioning  # Conditioning inputs are incompatible with mozilla_cv
                 fetcher_fn = load_mozilla_cv
@@ -118,7 +153,7 @@ class TextWavLoader(torch.utils.data.Dataset):
                 fetcher_fn = load_voxpopuli
             else:
                 raise NotImplementedError()
-            self.audiopaths_and_text.extend(fetcher_fn(p))
+            self.audiopaths_and_text.extend(fetcher_fn(p, type))
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
         random.seed(hparams.seed)
@@ -138,10 +173,10 @@ class TextWavLoader(torch.utils.data.Dataset):
 
     def get_wav_text_pair(self, audiopath_and_text):
         # separate filename and text
-        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+        audiopath, text, type = audiopath_and_text[0], audiopath_and_text[1], audiopath_and_text[2]
         text_seq = self.get_text(text)
         wav = load_audio(audiopath, self.sample_rate)
-        return (text_seq, wav, text, audiopath_and_text[0])
+        return (text_seq, wav, text, audiopath_and_text[0], type)
 
     def get_text(self, text):
         tokens = self.tokenizer.encode(text)
@@ -156,7 +191,7 @@ class TextWavLoader(torch.utils.data.Dataset):
     def __getitem__(self, index):
         self.skipped_items += 1
         try:
-            tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
+            tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
             if text is None or len(text.strip()) == 0:
                 raise ValueError
             if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
@@ -202,6 +237,7 @@ class TextWavLoader(torch.utils.data.Dataset):
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path,
             'skipped_items': actually_skipped_items,
+            'type': type,
         }
         if self.load_conditioning:
             res['conditioning'] = cond
@@ -249,7 +285,7 @@ if __name__ == '__main__':
     batch_sz = 8
     params = {
         'mode': 'paired_voice_audio',
-        'path': ['Y:\\clips\\books1\\transcribed-w2v.tsv'],
+        'path': ['Y:\\clips\\books1\\transcribed-oco.tsv'],
         'fetcher_mode': ['tsv'],
         'phase': 'train',
         'n_workers': 0,
diff --git a/codes/models/audio/tts/tacotron2/taco_utils.py b/codes/models/audio/tts/tacotron2/taco_utils.py
index c6297bfe..13c63fa5 100644
--- a/codes/models/audio/tts/tacotron2/taco_utils.py
+++ b/codes/models/audio/tts/tacotron2/taco_utils.py
@@ -26,6 +26,14 @@ def load_wav_to_torch(full_path):
     return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
 
 
+def load_filepaths_and_text_type(filename, type, split="|"):
+    with open(filename, encoding='utf-8') as f:
+        filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
+        base = os.path.dirname(filename)
+        for j in range(len(filepaths_and_text)):
+            filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
+    return filepaths_and_text
+
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index 22d6fce6..5f37141d 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -238,10 +238,10 @@ class MelEncoder(nn.Module):
 
 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
-                 mel_length_compression=1024, number_text_tokens=256,
-                 start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
+                 mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
+                 types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -252,7 +252,6 @@ class UnifiedVoice(nn.Module):
             max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
             mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
             number_text_tokens:
-            start_text_token:
             stop_text_token:
             number_mel_codes:
             start_mel_token:
@@ -265,8 +264,8 @@ class UnifiedVoice(nn.Module):
         super().__init__()
 
         self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
+        self.start_text_token = number_text_tokens * types
+        self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
@@ -279,7 +278,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
@@ -294,7 +293,7 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0
 
         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
 
         # Initialize the embeddings per the GPT-2 scheme
@@ -366,7 +365,7 @@ class UnifiedVoice(nn.Module):
         else:
             return first_logits
 
-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
                 return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@@ -383,6 +382,10 @@ class UnifiedVoice(nn.Module):
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
+        # Types are expressed by expanding the text embedding space.
+        if types is not None:
+            text_inputs = text_inputs * (1+types).unsqueeze(-1)
+
         if clip_inputs:
             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
             # chopping the inputs by the maximum actual length.
@@ -579,10 +582,11 @@ def register_unified_voice2(opt_net, opt):
 
 
 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
-            torch.tensor([250*256,195*256]))
+            torch.tensor([250*256,195*256]),
+            types=torch.tensor([0, 1]))
    #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
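
Note (not part of the patch): the change conditions UnifiedVoice on a per-sample dataset type by enlarging the text token space rather than adding a separate embedding. The standalone sketch below, written for illustration only, reproduces that arithmetic with the default number_text_tokens=256 and types=2; the tensor shapes and values are made up.

import torch
import torch.nn as nn

number_text_tokens = 256   # default text vocab size in unified_voice2.py
types = 2                  # number of dataset types, as in UnifiedVoice(types=2)

# A toy batch: token ids in [0, 255] plus one dataset type per sample
# (the datasets now emit this under the 'type' key of each batch).
text_inputs = torch.randint(high=number_text_tokens, size=(2, 5))
type_ids = torch.tensor([0, 1])

# Mirrors UnifiedVoice.forward: type 0 leaves ids untouched, type 1 doubles them, etc.,
# so each type's tokens land in a different region of the enlarged table.
shifted = text_inputs * (1 + type_ids).unsqueeze(-1)

# The embedding table and text head grow to number_text_tokens*types+1 rows;
# the start token becomes number_text_tokens*types (512 here) and stop stays 0.
text_embedding = nn.Embedding(number_text_tokens * types + 1, 512)
start_text_token = number_text_tokens * types

assert int(shifted.max()) <= start_text_token   # max scaled id is 255 * 2 = 510
print(text_embedding(shifted).shape)            # torch.Size([2, 5, 512])

With this scheme, type 0 keeps the original token ids while higher types scale them upward, and the start token at number_text_tokens*types always sits past every scaled id, so no index can overflow the embedding.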