From 8e94abd341cd16d8dbdbfaf981073e7fa9ca699d Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 18 Apr 2022 14:47:44 -0600 Subject: [PATCH] Support CVVP & fix for major bug in API --- api.py | 32 ++++-- models/autoregressive.py | 3 +- models/{text_voice_clip.py => clvp.py} | 4 +- models/cvvp.py | 133 +++++++++++++++++++++++++ read.py | 2 +- 5 files changed, 161 insertions(+), 13 deletions(-) rename models/{text_voice_clip.py => clvp.py} (98%) diff --git a/api.py b/api.py index 04c3af8..c436f8b 100644 --- a/api.py +++ b/api.py @@ -7,12 +7,13 @@ import torch import torch.nn.functional as F import progressbar +from models.cvvp import CVVP from models.diffusion_decoder import DiffusionTts from models.autoregressive import UnifiedVoice from tqdm import tqdm from models.arch_util import TorchMelSpectrogram -from models.text_voice_clip import VoiceCLIP +from models.clvp import CLVP from models.vocoder import UnivNetGenerator from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule @@ -175,11 +176,15 @@ class TextToSpeech: average_conditioning_embeddings=True).cpu().eval() self.autoregressive_for_diffusion.load_state_dict(torch.load('.models/autoregressive.pth')) - self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, - text_seq_len=350, text_heads=8, - num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, - use_xformers=True).cpu().eval() - self.clip.load_state_dict(torch.load('.models/clip.pth')) + self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, + text_seq_len=350, text_heads=8, + num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, + use_xformers=True).cpu().eval() + self.clvp.load_state_dict(torch.load('.models/clip.pth')) + + self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, + speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() + self.cvvp.load_state_dict(torch.load('.models/cvvp.pth')) self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, @@ -216,6 +221,8 @@ class TextToSpeech: def tts(self, text, voice_samples, k=1, # autoregressive generation parameters follow num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + # CLVP & CVVP parameters + clvp_cvvp_slider=.5, # diffusion generation parameters follow diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, **hf_generate_kwargs): @@ -253,15 +260,22 @@ class TextToSpeech: self.autoregressive = self.autoregressive.cpu() clip_results = [] - self.clip = self.clip.cuda() + self.clvp = self.clvp.cuda() + self.cvvp = self.cvvp.cuda() for batch in samples: for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) - clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False)) + clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False) + cvvp_accumulator = 0 + for cl in range(conds.shape[1]): + cvvp_accumulator = cvvp_accumulator + self.cvvp(conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False ) + cvvp = cvvp_accumulator / conds.shape[1] + clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider)) clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) best_results = samples[torch.topk(clip_results, k=k).indices] - self.clip = self.clip.cpu() + self.clvp = self.clvp.cpu() + self.cvvp = self.cvvp.cpu() del samples # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning diff --git a/models/autoregressive.py b/models/autoregressive.py index 932e508..0c211f3 100644 --- a/models/autoregressive.py +++ b/models/autoregressive.py @@ -562,7 +562,8 @@ class UnifiedVoice(nn.Module): logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs) + max_length=max_length, logits_processor=logits_processor, + num_return_sequences=num_return_sequences, **hf_generate_kwargs) return gen[:, trunc_index:] diff --git a/models/text_voice_clip.py b/models/clvp.py similarity index 98% rename from models/text_voice_clip.py rename to models/clvp.py index 674e62b..ecb8c40 100644 --- a/models/text_voice_clip.py +++ b/models/clvp.py @@ -16,7 +16,7 @@ def masked_mean(t, mask, dim = 1): t = t.masked_fill(~mask[:, :, None], 0.) return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] -class VoiceCLIP(nn.Module): +class CLVP(nn.Module): """ CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding transcribed text. @@ -141,7 +141,7 @@ class VoiceCLIP(nn.Module): if __name__ == '__main__': - clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2) + clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2) clip(torch.randint(0,256,(2,120)), torch.tensor([50,100]), torch.randint(0,8192,(2,250)), diff --git a/models/cvvp.py b/models/cvvp.py index e69de29..0c9fd35 100644 --- a/models/cvvp.py +++ b/models/cvvp.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from torch.utils.checkpoint import checkpoint + +from models.arch_util import AttentionBlock +from models.xtransformers import ContinuousTransformerWrapper, Encoder + + +def exists(val): + return val is not None + + +def masked_mean(t, mask): + t = t.masked_fill(~mask, 0.) + return t.sum(dim = 1) / mask.sum(dim = 1) + + +class CollapsingTransformer(nn.Module): + def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs): + super().__init__() + self.transformer = ContinuousTransformerWrapper( + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=model_dim, + depth=depth, + heads=heads, + ff_dropout=dropout, + ff_mult=1, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + **encoder_kwargs, + )) + self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1), + AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False), + nn.Conv1d(output_dims, output_dims, 1)) + self.mask_percentage = mask_percentage + + def forward(self, x, **transformer_kwargs): + h = self.transformer(x, **transformer_kwargs) + h = h.permute(0,2,1) + h = checkpoint(self.pre_combiner, h).permute(0,2,1) + if self.training: + mask = torch.rand_like(h.float()) > self.mask_percentage + else: + mask = torch.ones_like(h.float()).bool() + return masked_mean(h, mask) + + +class ConvFormatEmbedding(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.emb = nn.Embedding(*args, **kwargs) + + def forward(self, x): + y = self.emb(x) + return y.permute(0,2,1) + + +class CVVP(nn.Module): + def __init__( + self, + model_dim=512, + transformer_heads=8, + dropout=.1, + conditioning_enc_depth=8, + cond_mask_percentage=0, + mel_channels=80, + mel_codes=None, + speech_enc_depth=8, + speech_mask_percentage=0, + latent_multiplier=1, + ): + super().__init__() + latent_dim = latent_multiplier*model_dim + self.temperature = nn.Parameter(torch.tensor(1.)) + + self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), + nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) + self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) + self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False) + + if mel_codes is None: + self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) + else: + self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) + self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) + self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False) + + def get_grad_norm_parameter_groups(self): + return { + 'conditioning': list(self.conditioning_transformer.parameters()), + 'speech': list(self.speech_transformer.parameters()), + } + + def forward( + self, + mel_cond, + mel_input, + return_loss=False + ): + cond_emb = self.cond_emb(mel_cond).permute(0,2,1) + enc_cond = self.conditioning_transformer(cond_emb) + cond_latents = self.to_conditioning_latent(enc_cond) + + speech_emb = self.speech_emb(mel_input).permute(0,2,1) + enc_speech = self.speech_transformer(speech_emb) + speech_latents = self.to_speech_latent(enc_speech) + + + cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)) + temp = self.temperature.exp() + + if not return_loss: + sim = einsum('n d, n d -> n', cond_latents, speech_latents) * temp + return sim + + sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp + labels = torch.arange(cond_latents.shape[0], device=mel_input.device) + loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + + return loss + + +if __name__ == '__main__': + clvp = CVVP() + clvp(torch.randn(2,80,100), + torch.randn(2,80,95), + return_loss=True) \ No newline at end of file diff --git a/read.py b/read.py index 22623ac..fbff527 100644 --- a/read.py +++ b/read.py @@ -28,7 +28,7 @@ def split_and_recombine_text(texts, desired_length=200, max_len=300): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") + parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood2.txt") parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='patrick_stewart') parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')