From 5ae816bead39c4c2a167e3d6670d33dd7449db05 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Feb 2022 15:59:53 -0700 Subject: [PATCH] ctc gen checkin --- codes/data/audio/fast_paired_dataset.py | 38 ++-- codes/models/gpt_voice/ctc_code_generator2.py | 143 +++++++++++---- .../gen/use_diffuse_voice_translation.py | 169 ++++++++++++++++++ 3 files changed, 299 insertions(+), 51 deletions(-) create mode 100644 codes/scripts/audio/gen/use_diffuse_voice_translation.py diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py index ed405608..11d9377f 100644 --- a/codes/data/audio/fast_paired_dataset.py +++ b/codes/data/audio/fast_paired_dataset.py @@ -115,34 +115,34 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): def get_ctc_metadata(self, codes): grouped = groupby(codes.tolist()) - codes, repeats, pads = [], [], [0] + rcodes, repeats, seps = [], [], [0] for val, group in grouped: if val == 0: - pads[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character proceeding it. + seps[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character proceeding it. else: - codes.append(val) + rcodes.append(val) repeats.append(len(list(group))) - pads.append(0) + seps.append(0) - codes = torch.tensor(codes) + rcodes = torch.tensor(rcodes) # These clip values are sane maximum values which I did not see in the datasets I have access to. - repeats = torch.clip(torch.tensor(repeats), max=30) - pads = torch.clip(torch.tensor(pads[:-1]), max=120) + repeats = torch.clip(torch.tensor(repeats), min=1, max=30) + seps = torch.clip(torch.tensor(seps[:-1]), max=120) # Pad or clip the codes to get them to exactly self.max_text_len - orig_lens = codes.shape[0] - if codes.shape[0] < self.max_text_len: - gap = self.max_text_len - codes.shape[0] - codes = F.pad(codes, (0, gap)) - repeats = F.pad(repeats, (0, gap)) - pads = F.pad(pads, (0, gap)) - elif codes.shape[0] > self.max_text_len: - codes = codes[:self.max_text_len] - repeats = codes[:self.max_text_len] - pads = pads[:self.max_text_len] + orig_lens = rcodes.shape[0] + if rcodes.shape[0] < self.max_text_len: + gap = self.max_text_len - rcodes.shape[0] + rcodes = F.pad(rcodes, (0, gap)) + repeats = F.pad(repeats, (0, gap), value=1) # The minimum value for repeats is 1, hence this is the pad value too. + seps = F.pad(seps, (0, gap)) + elif rcodes.shape[0] > self.max_text_len: + rcodes = rcodes[:self.max_text_len] + repeats = rcodes[:self.max_text_len] + seps = seps[:self.max_text_len] return { - 'ctc_raw_codes': codes, - 'ctc_pads': pads, + 'ctc_raw_codes': rcodes, + 'ctc_separators': seps, 'ctc_repeats': repeats, 'ctc_raw_lengths': orig_lens, } diff --git a/codes/models/gpt_voice/ctc_code_generator2.py b/codes/models/gpt_voice/ctc_code_generator2.py index 1d6db586..24fa37a1 100644 --- a/codes/models/gpt_voice/ctc_code_generator2.py +++ b/codes/models/gpt_voice/ctc_code_generator2.py @@ -1,29 +1,34 @@ import functools +import json import torch import torch.nn as nn import torch.nn.functional as F -from transformers import T5Config, T5Model +from torch.nn import CrossEntropyLoss +from transformers import T5Config, T5Model, T5PreTrainedModel, T5ForConditionalGeneration +from transformers.file_utils import replace_return_docstrings +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput +from transformers.utils.model_parallel_utils import get_device_map, assert_device_map from x_transformers import Encoder, XTransformer from models.gpt_voice.transformer_builders import null_position_embeddings from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer from models.gpt_voice.unified_voice2 import ConditioningEncoder +from models.tacotron2.text.cleaners import english_cleaners from trainer.networks import register_model from utils.util import opt_get class CtcCodeGenerator(nn.Module): - def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=120, max_repeat=30, checkpointing=True): + def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, checkpointing=True): super().__init__() self.max_pad = max_pad self.max_repeat = max_repeat - self.start_token = (self.max_repeat+1)*(self.max_pad+1)+1 + self.start_token = self.max_repeat*self.max_pad+1 self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads) self.embedding = nn.Embedding(ctc_codes, model_dim) - self.dec_embedding = nn.Embedding(self.start_token+1, model_dim) self.config = T5Config( - vocab_size=1, # T5 embedding will be removed and replaced with custom embedding. + vocab_size=self.start_token+1, d_model=model_dim, d_kv=model_dim//num_heads, d_ff=model_dim*4, @@ -32,29 +37,36 @@ class CtcCodeGenerator(nn.Module): dropout_rate=dropout, feed_forward_proj='gated-gelu', use_cache=not checkpointing, - gradient_checkpointing=checkpointing + gradient_checkpointing=checkpointing, + tie_word_embeddings=False, + tie_encoder_decoder=False, + decoder_start_token_id=self.start_token, + pad_token_id=0, ) - self.transformer = T5Model(self.config) + self.transformer = T5ForConditionalGeneration(self.config) del self.transformer.encoder.embed_tokens - del self.transformer.decoder.embed_tokens + del self.transformer.shared self.transformer.encoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim) - self.transformer.decoder.embed_tokens = functools.partial(null_position_embeddings, dim=model_dim) - self.output_layer = nn.Linear(model_dim, self.start_token+1) - - def forward(self, conditioning_input, codes, pads, repeats, unpadded_lengths): + def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths): max_len = unpadded_lengths.max() codes = codes[:, :max_len] - pads = pads[:, :max_len] + separators = separators[:, :max_len] repeats = repeats[:, :max_len] - - if pads.max() > self.max_pad: - print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}") - pads = torch.clip(pads, 0, self.max_pad) + if separators.max() > self.max_pad: + print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}") + separators = torch.clip(separators, 0, self.max_pad) if repeats.max() > self.max_repeat: print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}") repeats = torch.clip(repeats, 0, self.max_repeat) + assert not torch.any(repeats < 1) + repeats = repeats - 1 # Per above, min(repeats) is 1; make it 0 to avoid wasting a prediction slot. + assert codes.max() < 36, codes.max() + labels = separators + repeats * self.max_pad + labels = labels + 1 # We want '0' to be used as the EOS or padding token, so add 1. + for i in range(unpadded_lengths.shape[0]): + labels[i, unpadded_lengths[i]:] = 0 conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input conds = [] @@ -63,32 +75,99 @@ class CtcCodeGenerator(nn.Module): conds = torch.stack(conds, dim=1) h = torch.cat([conds, self.embedding(codes)], dim=1) - labels = pads + repeats * self.max_pad + 1 - for i in range(unpadded_lengths.shape[0]): - labels[i, unpadded_lengths[i]:] = 0 - labels_in = F.pad(labels, (1,0), value=self.start_token) - h_dec = self.dec_embedding(labels_in) - - h = self.transformer(inputs_embeds=h, decoder_inputs_embeds=h_dec).last_hidden_state - logits = self.output_layer(h) - logits = logits.permute(0,2,1)[:,:,:-1] # Strip off the last token. There is no "stop" token here, so this is just an irrelevant prediction on some future that doesn't actually exist. - loss = F.cross_entropy(logits, labels, reduction='none') - - # Ignore the first predictions of the sequences. This corresponds to the padding for the first CTC character, which is pretty much random and cannot be predicted. - #loss = loss[1:].mean() + decoder_inputs = F.pad(labels, (1, 0), value=self.start_token)[:, :-1] + loss = self.transformer(inputs_embeds=h, decoder_input_ids=decoder_inputs, labels=labels).loss return loss + def generate(self, speech_conditioning_inputs, texts, **hf_generate_kwargs): + codes = [] + max_seq = 50 + for text in texts: + # First, generate CTC codes from the given texts. + vocab = json.loads('{" ": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "\'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}') + text = english_cleaners(text) + text = text.strip().upper() + cd = [] + for c in text: + if c not in vocab.keys(): + continue + cd.append(vocab[c]) + codes.append(torch.tensor(cd, device=speech_conditioning_inputs.device)) + max_seq = max(max_seq, codes[-1].shape[-1]) + # Collate + for i in range(len(codes)): + if codes[i].shape[-1] < max_seq: + codes[i] = F.pad(codes[i], (0, max_seq-codes[i].shape[-1])) + codes = torch.stack(codes, dim=0) + + conditioning_input = speech_conditioning_inputs.unsqueeze(1) if len(speech_conditioning_inputs.shape) == 3 else speech_conditioning_inputs + conds = [] + for j in range(conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + h = torch.cat([conds, self.embedding(codes)], dim=1) + generate = self.transformer.generate(inputs_embeds=h, max_length=codes.shape[-1]+1, min_length=codes.shape[-1]+1, + bos_token_id=self.start_token, + bad_words_ids=[[0], [self.start_token]], **hf_generate_kwargs) + # The HF generate API returns a sequence with the BOS token included, hence the +1s above. Remove it. + generate = generate[:, 1:] + + # De-compress the codes from the generated output + generate = generate - 1 # Remember above when we added 1 to the labels to avoid overlapping the EOS pad token? + pads = generate % self.max_pad + repeats = (generate // self.max_pad) + 1 + ctc_batch = [] + max_seq = 0 + for bc, bp, br in zip(codes, pads, repeats): + ctc = [] + for c, p, r in zip(bc, bp, br): + for _ in range(p): + ctc.append(0) + for _ in range(r): + ctc.append(c.item()) + ctc_batch.append(torch.tensor(ctc, device=speech_conditioning_inputs.device)) + max_seq = max(max_seq, ctc_batch[-1].shape[-1]) + + # Collate the batch + for i in range(len(ctc_batch)): + if ctc_batch[i].shape[-1] < max_seq: + ctc_batch[i] = F.pad(ctc_batch[i], (0, max_seq-ctc_batch[i].shape[-1])) + return torch.stack(ctc_batch, dim=0) + @register_model def register_ctc_code_generator2(opt_net, opt): return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {})) +def inf(): + sd = torch.load('D:\\dlas\\experiments\\train_encoder_build_ctc_alignments\\models\\24000_generator.pth', map_location='cpu') + model = CtcCodeGenerator(layers=10, checkpointing=False).eval() + model.load_state_dict(sd) + raw_batch = torch.load('raw_batch.pth') + with torch.no_grad(): + from data.audio.unsupervised_audio_dataset import load_audio + from scripts.audio.gen.speech_synthesis_utils import wav_to_mel + ref_mel = torch.cat([wav_to_mel(raw_batch['conditioning'][0])[:, :, :256], + wav_to_mel(raw_batch['conditioning'][0])[:, :, :256]], dim=0).unsqueeze(0) + loss = model(ref_mel, raw_batch['ctc_raw_codes'][0].unsqueeze(0), + raw_batch['ctc_pads'][0].unsqueeze(0), + raw_batch['ctc_repeats'][0].unsqueeze(0), + raw_batch['ctc_raw_lengths'][0].unsqueeze(0),) + #ref_mel = torch.cat([wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\1.wav", 22050))[:, :, :256], + # wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\2.wav", 22050))[:, :, :256]], dim=0).unsqueeze(0) + #ctc = model.generate(ref_mel, ["i suppose though it's too early for them"], num_beams=4, ) + print("Break") + + if __name__ == '__main__': + inf() + model = CtcCodeGenerator() conds = torch.randn(4,2,80,600) inps = torch.randint(0,36, (4, 300)) pads = torch.randint(0,100, (4,300)) repeats = torch.randint(0,20, (4,300)) - loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30])) - print(loss.shape) \ No newline at end of file + #loss = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30])) + #print(loss.shape) + #model.generate(conds, ["Hello, world!", "Ahoi!", "KKKKKK", "what's going on??"]) \ No newline at end of file diff --git a/codes/scripts/audio/gen/use_diffuse_voice_translation.py b/codes/scripts/audio/gen/use_diffuse_voice_translation.py new file mode 100644 index 00000000..a60b1b8f --- /dev/null +++ b/codes/scripts/audio/gen/use_diffuse_voice_translation.py @@ -0,0 +1,169 @@ +import argparse +import os + +import torch +import torchaudio + +from data.audio.unsupervised_audio_dataset import load_audio +from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \ + load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes +from utils.audio import plot_spectrogram +from utils.util import load_model_from_config + + +def ceil_multiple(base, multiple): + res = base % multiple + if res == 0: + return base + return base + (multiple - res) + + +if __name__ == '__main__': + conditioning_clips = { + # Male + 'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', + 'carlin': 'Y:\\clips\\books1\\12_dchha13 Bubonic Nukes\\00097.wav', + 'entangled': 'Y:\\clips\\books1\\3857_25_The_Entangled_Bank__000000000\\00123.wav', + 'snowden': 'Y:\\clips\\books1\\7658_Edward_Snowden_-_Permanent_Record__000000004\\00027.wav', + # Female + 'the_doctor': 'Y:\\clips\\books2\\37062___The_Doctor__000000003\\00206.wav', + 'puppy': 'Y:\\clips\\books2\\17830___3_Puppy_Kisses__000000002\\00046.wav', + 'adrift': 'Y:\\clips\\books2\\5608_Gear__W_Michael_-_Donovan_1-5_(2018-2021)_(book_4_Gear__W_Michael_-_Donovan_5_-_Adrift_(2021)_Gear__W_Michael_-_Adrift_(Donovan_5)_—_82__000000000\\00019.wav', + } + + provided_codes = [ + # but facts within easy reach of any one who cares to know them go to say that the greater abstenence of women is in some part + # due to an imperative conventionality and this conventionality is in a general way strongest were the patriarchal tradition + # the tradition that the woman is a chattel has retained its hold in greatest vigor + # 3570/5694/3570_5694_000008_000001.wav + [0, 0, 24, 0, 16, 0, 6, 0, 4, 0, 0, 0, 0, 0, 20, 0, 7, 0, 0, 19, 19, 0, 0, 6, 0, 0, 12, 12, 0, 4, 4, 0, 18, 18, + 0, 10, 0, 6, 11, 11, 10, 10, 9, 9, 4, 4, 4, 5, 5, 0, 7, 0, 0, 0, 0, 12, 0, 22, 22, 0, 4, 4, 0, 13, 13, 5, 0, 7, + 7, 0, 0, 19, 11, 0, 4, 4, 8, 20, 4, 4, 4, 7, 0, 9, 9, 0, 22, 4, 4, 0, 8, 0, 9, 5, 4, 4, 18, 11, 11, 8, 4, 4, 0, + 0, 0, 19, 19, 7, 0, 0, 13, 5, 5, 0, 12, 12, 4, 4, 6, 6, 8, 8, 4, 4, 0, 26, 9, 9, 8, 0, 18, 0, 0, 4, 4, 6, 6, + 11, 5, 0, 17, 17, 0, 0, 4, 4, 4, 4, 0, 0, 0, 21, 0, 8, 0, 0, 0, 0, 4, 4, 6, 6, 8, 0, 4, 4, 0, 0, 12, 0, 7, 7, + 0, 0, 22, 0, 4, 4, 6, 11, 11, 7, 6, 6, 4, 4, 6, 11, 5, 4, 4, 4, 0, 21, 0, 13, 5, 5, 7, 7, 0, 0, 6, 6, 5, 0, 13, + 0, 4, 4, 0, 7, 0, 0, 0, 24, 0, 0, 12, 12, 0, 0, 6, 0, 5, 0, 0, 9, 9, 0, 5, 0, 9, 0, 0, 19, 5, 5, 4, 4, 8, 20, + 20, 4, 4, 4, 4, 0, 18, 18, 8, 0, 0, 0, 17, 0, 5, 0, 9, 0, 0, 0, 4, 4, 4, 4, 0, 0, 0, 10, 0, 0, 12, 12, 4, 4, 0, + 10, 0, 9, 0, 4, 4, 0, 0, 12, 0, 0, 8, 0, 17, 5, 5, 4, 4, 0, 0, 0, 23, 23, 0, 7, 0, 13, 0, 0, 0, 6, 0, 4, 0, 0, + 0, 0, 14, 0, 16, 16, 0, 0, 5, 0, 4, 4, 0, 6, 8, 0, 4, 4, 7, 9, 4, 4, 4, 0, 10, 10, 17, 0, 0, 0, 23, 0, 5, 0, 0, + 13, 13, 0, 7, 0, 0, 6, 6, 0, 10, 0, 25, 5, 5, 4, 4, 0, 0, 0, 19, 19, 8, 8, 9, 0, 0, 0, 0, 0, 25, 0, 5, 0, 9, 0, + 0, 0, 6, 6, 10, 8, 8, 0, 9, 0, 0, 0, 7, 0, 0, 15, 0, 10, 0, 0, 0, 0, 6, 6, 0, 0, 22, 0, 0, 0, 4, 4, 4, 4, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 7, 0, 9, 14, 0, 4, 0, 0, 6, 11, 10, 0, 0, 0, 12, 0, 4, 4, 0, 19, 19, 8, 9, 9, 0, 0, 25, 0, 5, 0, 9, 0, 0, 6, 6, + 10, 8, 8, 9, 9, 0, 0, 7, 0, 0, 15, 0, 10, 0, 0, 0, 0, 6, 0, 22, 22, 0, 4, 4, 0, 0, 10, 0, 0, 0, 0, 12, 12, 0, + 0, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 10, 0, 9, 4, 4, 4, 7, 4, 4, 4, 0, 21, 0, 5, 0, 9, 0, 5, 5, 13, 13, 7, 0, 15, + 15, 0, 0, 4, 4, 0, 18, 18, 0, 7, 0, 0, 22, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 12, 12, 0, 0, 0, 6, 6, 13, 13, 8, 0, 0, 9, 9, 0, 21, 0, 0, 5, 5, 0, 0, 0, 12, 12, 0, 0, 6, + 0, 0, 0, 4, 4, 0, 0, 0, 18, 0, 5, 0, 13, 0, 5, 4, 4, 6, 11, 5, 0, 4, 4, 23, 23, 7, 7, 0, 0, 0, 6, 0, 13, 13, + 10, 10, 0, 0, 0, 0, 7, 13, 13, 0, 19, 11, 11, 0, 0, 7, 15, 15, 0, 0, 4, 4, 0, 6, 13, 13, 7, 7, 0, 0, 0, 14, 10, + 10, 0, 0, 0, 0, 0, 6, 10, 10, 8, 8, 9, 0, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 11, 5, 0, 4, 4, 0, 6, 13, 13, 7, 7, 0, 0, 0, 14, 10, 10, 0, 0, 0, 6, 10, 10, + 8, 9, 9, 0, 0, 4, 4, 0, 6, 11, 7, 0, 6, 4, 4, 6, 11, 5, 4, 4, 4, 18, 18, 8, 0, 0, 17, 7, 0, 9, 0, 4, 10, 0, 0, + 12, 12, 4, 4, 4, 7, 4, 4, 0, 0, 0, 19, 11, 0, 7, 0, 6, 0, 0, 0, 6, 0, 5, 0, 15, 15, 0, 0, 0, 4, 4, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 7, 0, 0, 0, 12, 0, 0, 4, 4, 0, 13, 5, 5, 0, 0, 0, 0, 6, 6, 0, 0, + 7, 10, 10, 0, 9, 0, 5, 0, 14, 4, 4, 4, 0, 10, 0, 0, 0, 6, 0, 0, 0, 0, 0, 12, 0, 4, 4, 0, 0, 0, 11, 0, 0, 8, 0, + 0, 0, 15, 0, 0, 14, 0, 4, 4, 4, 0, 10, 0, 9, 4, 4, 4, 4, 4, 0, 21, 0, 13, 5, 5, 7, 7, 0, 0, 6, 0, 5, 0, 0, 12, + 0, 6, 0, 4, 0, 0, 25, 10, 0, 0, 0, 21, 0, 8, 0, 0, 13, 13, 0, 0, 4, 4, 4, 4, 0, 0, 0], + # the competitor with whom the entertainer wishes to institute a comparison is by this method made to serve as a means to the end + # 3570/5694/3570_5694_000011_000005.wav + [0, 0, 6, 11, 5, 0, 4, 0, 19, 19, 8, 17, 0, 0, 0, 0, 23, 0, 5, 5, 0, 0, 6, 6, 10, 10, 0, 0, 6, 6, 0, 8, 0, 13, + 13, 0, 4, 4, 18, 18, 10, 0, 6, 11, 11, 4, 4, 4, 0, 0, 18, 18, 11, 0, 8, 0, 0, 0, 0, 17, 0, 0, 4, 0, 6, 11, 5, + 0, 4, 4, 0, 5, 9, 9, 0, 6, 5, 5, 13, 13, 0, 0, 6, 6, 0, 7, 0, 10, 0, 9, 0, 0, 5, 0, 13, 4, 4, 0, 18, 10, 10, 0, + 0, 12, 11, 11, 0, 5, 0, 0, 0, 12, 0, 0, 4, 4, 0, 0, 6, 6, 8, 0, 0, 4, 4, 4, 0, 10, 9, 9, 0, 0, 0, 0, 12, 0, 0, + 6, 0, 10, 0, 0, 0, 6, 0, 16, 16, 0, 6, 5, 0, 4, 4, 7, 4, 4, 19, 19, 8, 0, 17, 0, 0, 0, 0, 0, 23, 0, 0, 7, 0, 0, + 0, 13, 0, 10, 0, 0, 0, 0, 0, 12, 0, 0, 8, 0, 9, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 24, 0, 22, 0, 4, 4, + 0, 6, 11, 10, 0, 0, 0, 12, 0, 0, 4, 4, 0, 0, 17, 5, 5, 0, 0, 0, 6, 11, 11, 8, 0, 0, 14, 14, 0, 0, 4, 4, 4, 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 17, 7, 0, 0, 0, 0, 14, 5, 0, 4, 4, 6, 8, 4, + 4, 0, 0, 0, 12, 12, 0, 5, 5, 0, 13, 13, 0, 25, 5, 4, 4, 7, 0, 12, 4, 4, 4, 7, 4, 4, 0, 17, 5, 0, 0, 7, 0, 0, 9, + 0, 0, 0, 0, 12, 0, 4, 4, 0, 6, 0, 8, 0, 4, 4, 6, 11, 5, 4, 4, 4, 0, 0, 5, 0, 9, 9, 0, 0, 0, 0, 14, 0, 0, 4, 4, + 4, 4, 4, 0, 0], + # the livery becomes obnoxious to nearly all who are required to wear it + # 3570/5694/3570_5694_000014_000021.wav + [0, 0, 6, 11, 5, 0, 0, 4, 4, 0, 15, 10, 10, 0, 0, 25, 5, 0, 13, 13, 0, 22, 0, 0, 4, 0, 24, 24, 5, 0, 0, 0, 19, + 19, 0, 8, 0, 17, 5, 5, 0, 12, 0, 4, 4, 4, 0, 8, 0, 0, 24, 0, 0, 0, 9, 9, 0, 8, 0, 0, 0, 0, 0, 28, 0, 0, 0, 10, + 0, 8, 16, 0, 12, 12, 12, 0, 4, 0, 6, 6, 8, 0, 4, 4, 0, 9, 5, 0, 7, 7, 13, 0, 0, 15, 22, 22, 4, 4, 0, 0, 0, 0, + 0, 0, 0, 7, 0, 15, 0, 0, 15, 0, 4, 4, 4, 18, 11, 11, 8, 0, 4, 4, 0, 7, 0, 13, 5, 4, 4, 13, 13, 5, 0, 0, 0, 30, + 30, 16, 0, 0, 10, 0, 0, 0, 13, 5, 0, 14, 4, 4, 6, 6, 8, 0, 4, 4, 18, 18, 5, 5, 7, 7, 13, 13, 0, 4, 4, 0, 10, 0, + 0, 0, 0, 6, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0], + # in the nature of things luxuries and the comforts of life belong to the leisure class + # 3570/5694/3570_5694_000006_000007.wav + [0, 0, 0, 0, 0, 10, 9, 0, 4, 4, 6, 11, 5, 4, 4, 4, 9, 9, 7, 7, 0, 0, 0, 0, 0, 0, 6, 0, 16, 16, 13, 13, 5, 0, 4, 4, 8, 0, 20, 4, 4, 4, 0, 6, 0, 11, 10, 0, 9, 0, 21, 0, 0, 0, 12, 12, 0, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 15, 0, 16, 16, 0, 0, 28, 0, 0, 0, 16, 16, 0, 13, 13, 0, 10, 0, 5, 0, 0, 0, 12, 0, 0, 4, 4, 4, 0, 0, 7, 0, 9, 0, 14, 4, 4, 6, 11, 5, 4, 4, 0, 0, 19, 0, 8, 17, 17, 0, 0, 0, 0, 0, 20, 0, 8, 0, 13, 0, 6, 0, 12, 4, 4, 8, 0, 20, 4, 4, 4, 0, 0, 15, 0, 10, 10, 0, 0, 0, 20, 5, 0, 4, 4, 0, 0, 24, 5, 0, 0, 0, 15, 8, 0, 9, 0, 21, 0, 0, 0, 4, 4, 6, 8, 4, 4, 4, 6, 11, 5, 4, 4, 15, 15, 5, 10, 0, 0, 12, 0, 16, 13, 5, 5, 4, 4, 0, 19, 0, 15, 15, 0, 0, 7, 0, 0, 12, 12, 0, 0, 0, 12, 12, 0, 0, 0, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0], + # from arcaic times down through all the length of the patriarchal regime it has been the office of the women to + # prepare and administer these luxuries and it has been the perquisite of the men of gentle birth and breeding + # to consume them + # 3570/5694/3570_5694_000007_000003.wav + [0, 0, 0, 0, 0, 0, 20, 13, 8, 0, 17, 0, 4, 4, 0, 7, 0, 13, 0, 0, 0, 0, 0, 19, 0, 0, 0, 7, 0, 0, 0, 0, 10, 0, 19, 0, 0, 0, 4, 4, 0, 0, 0, 0, 6, 0, 0, 0, 10, 0, 0, 17, 5, 0, 0, 0, 12, 0, 4, 0, 0, 0, 0, 14, 0, 0, 8, 0, 18, 0, 0, 0, 9, 0, 0, 0, 0, 4, 4, 0, 0, 0, 6, 11, 13, 8, 0, 16, 21, 21, 11, 0, 4, 4, 7, 0, 15, 0, 15, 15, 4, 4, 6, 11, 5, 5, 4, 4, 0, 15, 0, 5, 0, 0, 9, 9, 0, 21, 0, 0, 6, 11, 0, 4, 4, 8, 8, 20, 4, 4, 4, 6, 11, 5, 4, 4, 0, 0, 0, 23, 0, 7, 7, 0, 0, 0, 0, 0, 6, 6, 13, 13, 13, 10, 0, 0, 0, 0, 0, 7, 13, 13, 0, 19, 11, 11, 11, 0, 0, 7, 15, 15, 0, 4, 4, 4, 13, 13, 5, 0, 0, 0, 0, 21, 21, 0, 0, 10, 0, 0, 0, 0, 17, 5, 0, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 6, 4, 4, 0, 0, 11, 7, 7, 0, 0, 12, 0, 4, 4, 0, 24, 5, 0, 0, 5, 5, 9, 0, 4, 6, 6, 11, 5, 4, 4, 0, 0, 8, 0, 20, 0, 0, 0, 20, 0, 10, 0, 0, 0, 19, 5, 0, 4, 4, 8, 0, 20, 4, 4, 6, 11, 5, 4, 4, 4, 18, 8, 0, 0, 0, 17, 5, 0, 9, 9, 0, 0, 4, 4, 0, 6, 6, 8, 0, 0, 4, 4, 0, 23, 23, 13, 5, 5, 0, 0, 0, 0, 23, 23, 0, 7, 0, 0, 0, 13, 5, 0, 0, 0, 4, 4, 0, 7, 0, 9, 14, 0, 4, 4, 0, 0, 7, 0, 14, 0, 0, 0, 17, 17, 10, 0, 9, 0, 10, 10, 0, 0, 12, 12, 0, 0, 0, 6, 0, 5, 13, 13, 0, 0, 0, 0, 4, 4, 4, 6, 11, 11, 5, 0, 0, 0, 12, 5, 5, 4, 4, 15, 15, 0, 16, 0, 0, 0, 28, 0, 0, 0, 16, 0, 0, 13, 13, 10, 0, 5, 5, 0, 0, 12, 12, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 9, 0, 14, 4, 4, 10, 0, 6, 4, 4, 0, 11, 11, 7, 0, 0, 0, 12, 0, 4, 4, 0, 0, 0, 0, 24, 5, 0, 0, 5, 5, 9, 9, 4, 4, 4, 6, 11, 5, 4, 4, 0, 0, 0, 23, 0, 5, 0, 13, 0, 0, 0, 0, 0, 30, 30, 16, 10, 10, 0, 0, 0, 12, 0, 10, 0, 0, 6, 5, 0, 4, 4, 8, 20, 0, 4, 4, 6, 11, 5, 4, 4, 0, 17, 5, 0, 0, 0, 9, 0, 0, 0, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 20, 4, 4, 4, 0, 0, 21, 0, 5, 5, 0, 9, 9, 0, 0, 0, 6, 0, 15, 0, 5, 0, 4, 0, 0, 0, 24, 0, 10, 0, 13, 0, 0, 0, 0, 6, 11, 0, 0, 4, 0, 0, 7, 0, 9, 14, 14, 4, 4, 4, 0, 0, 24, 13, 5, 0, 0, 0, 5, 0, 0, 14, 10, 0, 9, 21, 21, 0, 4, 4, 0, 6, 8, 0, 4, 4, 0, 19, 8, 0, 9, 0, 0, 0, 0, 0, 0, 0, 12, 0, 16, 0, 17, 5, 0, 0, 4, 4, 6, 11, 5, 0, 17, 0, 4, 4, 4, 4, 0, 0], + # yes it is perfection she declared + # 1284/1180/1284_1180_000036_000000.wav + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 22, 0, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 4, 4, 4, 4, 0, 0, 10, 0, 6, 0, 4, 4, 0, 0, 10, 0, 0, 0, 0, 0, 12, 0, 4, 4, 0, 0, 0, 23, 0, 5, 0, 13, 13, 0, 0, 0, 0, 0, 0, 0, 20, 0, 0, 5, 0, 0, 0, 19, 0, 0, 6, 6, 0, 10, 0, 8, 0, 9, 0, 0, 4, 4, 4, 4, 4, 0, 0, 0, 0, 12, 11, 11, 5, 0, 4, 4, 0, 14, 0, 5, 0, 0, 0, 0, 19, 15, 15, 0, 0, 7, 0, 0, 0, 13, 0, 5, 0, 14, 4, 4, 4, 4, 0, 0, 0], + # then it must be somewhere in the blue forest + # 1284/1180/1284_1180_000016_000002.wav + [0, 0, 0, 6, 11, 5, 0, 9, 0, 4, 4, 10, 6, 4, 4, 0, 17, 17, 16, 0, 0, 12, 0, 6, 4, 4, 0, 24, 5, 5, 0, 0, 4, 4, 0, 0, 12, 12, 0, 8, 0, 0, 17, 5, 5, 0, 0, 18, 18, 11, 5, 0, 13, 13, 5, 0, 4, 4, 10, 9, 4, 4, 6, 11, 5, 4, 4, 0, 24, 15, 15, 16, 16, 0, 5, 5, 0, 0, 4, 4, 0, 0, 0, 20, 8, 8, 8, 0, 0, 0, 13, 13, 0, 5, 5, 0, 0, 0, 0, 0, 12, 12, 0, 0, 6, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0], + # happy youth that is ready to pack its valus and start for cathay on an hour's notice + # 4970/29093/4970_29093_000044_000002.wav + [0, 0, 0, 0, 11, 0, 7, 23, 0, 0, 0, 0, 23, 0, 22, 22, 0, 0, 0, 4, 4, 0, 0, 22, 8, 8, 16, 16, 0, 0, 0, 6, 6, 11, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 11, 7, 6, 0, 4, 4, 10, 0, 0, 12, 0, 4, 0, 13, 13, 5, 0, 7, 0, 0, 14, 22, 0, 0, 0, 4, 0, 6, 0, 8, 4, 4, 0, 0, 0, 0, 0, 0, 23, 0, 7, 0, 0, 19, 0, 0, 26, 4, 4, 4, 10, 0, 6, 0, 12, 4, 4, 0, 0, 0, 25, 0, 7, 0, 0, 0, 15, 0, 0, 16, 0, 0, 0, 0, 12, 0, 0, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 9, 0, 14, 4, 4, 0, 12, 12, 0, 6, 0, 7, 0, 13, 0, 0, 0, 6, 0, 0, 4, 4, 0, 0, 0, 0, 20, 8, 0, 13, 0, 4, 4, 4, 0, 0, 19, 0, 7, 7, 0, 0, 0, 0, 0, 6, 11, 0, 0, 7, 0, 0, 0, 22, 0, 0, 0, 0, 0, 4, 4, 0, 0, 8, 0, 9, 0, 4, 4, 7, 9, 4, 4, 4, 0, 0, 0, 11, 8, 8, 16, 0, 0, 13, 13, 0, 0, 0, 27, 0, 12, 0, 4, 4, 0, 9, 8, 8, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 19, 5, 5, 0, 0, 4, 4, 4, 4, 4, 0], + # well then i must make some suggestions to you + # 1580/141084/1580_141084_000057_000000.wav + [0, 0, 0, 0, 0, 0, 0, 18, 0, 5, 0, 15, 0, 0, 15, 15, 4, 4, 0, 0, 6, 11, 5, 0, 0, 0, 9, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 10, 0, 4, 4, 0, 17, 0, 16, 0, 0, 12, 0, 6, 0, 4, 4, 0, 17, 17, 7, 0, 26, 5, 5, 4, 4, 0, 12, 12, 8, 8, 17, 17, 5, 0, 4, 4, 4, 12, 12, 16, 0, 21, 0, 0, 0, 0, 21, 21, 0, 5, 0, 0, 0, 12, 0, 0, 0, 6, 6, 0, 10, 0, 8, 8, 9, 0, 0, 0, 0, 0, 0, 12, 0, 0, 4, 4, 0, 0, 6, 0, 8, 0, 4, 4, 4, 0, 0, 22, 22, 0, 8, 16, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0], + # some others too big cotton county + # 1995/1826/1995_1826_000010_000002.wav + [0, 0, 0, 0, 12, 0, 8, 0, 17, 5, 4, 4, 0, 8, 0, 0, 6, 11, 5, 0, 13, 13, 0, 0, 12, 0, 4, 4, 0, 0, 6, 0, 8, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 0, 10, 0, 0, 0, 0, 21, 0, 0, 4, 4, 4, 0, 0, 0, 19, 0, 8, 0, 6, 6, 0, 0, 0, 6, 8, 0, 9, 9, 0, 0, 4, 0, 0, 0, 0, 19, 8, 8, 16, 0, 9, 9, 0, 0, 6, 6, 0, 0, 22, 0, 0, 0, 0, 4, 4, 0, 0, 0], + ] + + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium.yml') + parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator') + parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='X:\\dlas\\experiments\\train_diffusion_tts5_medium\\models\\73000_generator_ema.pth') + parser.add_argument('-sr_opt', type=str, help='Path to options YAML file used to train the SR diffusion model', default='X:\\dlas\\experiments\\train_diffusion_tts6_upsample.yml') + parser.add_argument('-sr_diffusion_model_name', type=str, help='Name of the SR diffusion model in opt.', default='generator') + parser.add_argument('-sr_diffusion_model_path', type=str, help='Path to saved model weights for the SR diffuser', default='X:\\dlas\\experiments\\train_diffusion_tts6_upsample\\models\\7000_generator_ema.pth') + parser.add_argument('-cond', type=str, help='Type of conditioning voice', default='carlin') + parser.add_argument('-diffusion_steps', type=int, help='Number of diffusion steps to perform to create the generate. Lower steps reduces quality, but >40 is generally pretty good.', default=100) + parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='../results/use_diffuse_tts') + parser.add_argument('-device', type=str, help='Device to run on', default='cuda') + args = parser.parse_args() + os.makedirs(args.output_path, exist_ok=True) + + # Fixed parameters. + base_sample_rate = 5500 + sr_sample_rate = 22050 + + print("Loading Diffusion Models..") + diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False, + load_path=args.diffusion_model_path, device='cpu').eval() + diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='cosine') + aligned_codes_compression_factor = base_sample_rate * 221 // 11025 + sr_diffusion = load_model_from_config(args.sr_opt, args.sr_diffusion_model_name, also_load_savepoint=False, + load_path=args.sr_diffusion_model_path, device='cpu').eval() + sr_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=args.diffusion_steps, schedule='linear') + sr_cond = load_audio(conditioning_clips[args.cond], sr_sample_rate).to(args.device) + if sr_cond.shape[-1] > 88000: + sr_cond = sr_cond[:,:88000] + cond = audio = torchaudio.functional.resample(sr_cond, sr_sample_rate, base_sample_rate) + torchaudio.save(os.path.join(args.output_path, 'cond_base.wav'), cond.cpu(), base_sample_rate) + torchaudio.save(os.path.join(args.output_path, 'cond_sr.wav'), sr_cond.cpu(), sr_sample_rate) + + with torch.no_grad(): + for p, code in enumerate(provided_codes): + print("Loading data..") + aligned_codes = torch.tensor(code).to(args.device) + + print("Performing initial diffusion..") + output_shape = (1, 1, ceil_multiple(aligned_codes.shape[-1]*aligned_codes_compression_factor, 2048)) + diffusion = diffusion.cuda() + output_base = diffuser.p_sample_loop(diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device), + model_kwargs={'tokens': aligned_codes.unsqueeze(0), + 'conditioning_input': cond.unsqueeze(0)}) + diffusion = diffusion.cpu() + torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean_base.wav'), output_base.cpu().squeeze(0), base_sample_rate) + + print("Performing SR diffusion..") + output_shape = (1, 1, output_base.shape[-1] * (sr_sample_rate // base_sample_rate)) + sr_diffusion = sr_diffusion.cuda() + output = diffuser.p_sample_loop(sr_diffusion, output_shape, noise=torch.zeros(output_shape, device=args.device), + model_kwargs={'tokens': aligned_codes.unsqueeze(0), + 'conditioning_input': sr_cond.unsqueeze(0), + 'lr_input': output_base}) + sr_diffusion = sr_diffusion.cpu() + torchaudio.save(os.path.join(args.output_path, f'{p}_output_mean_sr.wav'), output.cpu().squeeze(0), sr_sample_rate)