From f1adc125058e69e83a9a9b3f059128809b8dd6cc Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 28 Mar 2022 19:33:31 -0600 Subject: [PATCH] Upgrade CLIP model and add eval_multiple --- api.py | 214 ++++++++++++++++++++++++++++++++++++ do_tts.py | 14 +-- eval_multiple.py | 33 ++++++ models/arch_util.py | 46 +++++++- models/diffusion_decoder.py | 40 +------ models/text_voice_clip.py | 70 ++++++++---- 6 files changed, 350 insertions(+), 67 deletions(-) create mode 100644 api.py create mode 100644 eval_multiple.py diff --git a/api.py b/api.py new file mode 100644 index 0000000..28ce9ed --- /dev/null +++ b/api.py @@ -0,0 +1,214 @@ +import argparse +import os +import random +from urllib import request + +import torch +import torch.nn.functional as F +import torchaudio +import progressbar +import ocotillo + +from models.diffusion_decoder import DiffusionTts +from models.autoregressive import UnifiedVoice +from tqdm import tqdm + +from models.arch_util import TorchMelSpectrogram +from models.text_voice_clip import VoiceCLIP +from models.vocoder import UnivNetGenerator +from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel +from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule +from utils.tokenizer import VoiceBpeTokenizer, lev_distance + + +pbar = None +def download_models(): + MODELS = { + 'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin', + 'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin', + 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin' + } + os.makedirs('.models', exist_ok=True) + def show_progress(block_num, block_size, total_size): + global pbar + if pbar is None: + pbar = progressbar.ProgressBar(maxval=total_size) + pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + pbar.update(downloaded) + else: + pbar.finish() + pbar = None + for model_name, url in MODELS.items(): + if os.path.exists(f'.models/{model_name}'): + continue + print(f'Downloading {model_name} from {url}...') + request.urlretrieve(url, f'.models/{model_name}', show_progress) + print('Done.') + + +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. + """ + return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), + conditioning_free=cond_free, conditioning_free_k=1) + + +def load_conditioning(clip, cond_length=132300): + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start:rand_start + cond_length] + mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).cuda() + + +def fix_autoregressive_output(codes, stop_token): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + print("No stop tokens found, enjoy that output of yours!") + return codes + else: + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + + return codes + + +def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False): + """ + Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip. + """ + with torch.no_grad(): + cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False) + # Pad MEL to multiples of 32 + msl = mel_codes.shape[-1] + dsl = 32 + gap = dsl - (msl % dsl) + if gap > 0: + mel = torch.nn.functional.pad(mel_codes, (0, gap)) + + output_shape = (mel.shape[0], 100, mel.shape[-1]*4) + precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel) + if mean: + mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device), + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}) + else: + mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}) + return denormalize_tacotron_mel(mel)[:,:,:msl*4] + + +class TextToSpeech: + def __init__(self, autoregressive_batch_size=32): + self.autoregressive_batch_size = autoregressive_batch_size + self.tokenizer = VoiceBpeTokenizer() + download_models() + + self.autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, + train_solo_embeddings=False, + average_conditioning_embeddings=True).cpu().eval() + self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth')) + + self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, + text_seq_len=350, text_heads=8, + num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, + use_xformers=True).cpu().eval() + self.clip.load_state_dict(torch.load('.models/clip.pth')) + + self.diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024, + channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], + token_conditioning_resolutions=[1, 4, 8], + dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2, + time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2, + conditioning_expansion=1).cpu().eval() + self.diffusion.load_state_dict(torch.load('.models/diffusion.pth')) + + self.vocoder = UnivNetGenerator().cpu() + self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) + self.vocoder.eval(inference=True) + + def tts(self, text, voice_samples, num_autoregressive_samples=512, k=1, diffusion_iterations=100, cond_free=True): + text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() + text = F.pad(text, (0, 1)) # This may not be necessary. + + conds = [] + if not isinstance(voice_samples, list): + voice_samples = [voice_samples] + for vs in voice_samples: + conds.append(load_conditioning(vs)) + conds = torch.stack(conds, dim=1) + cond_diffusion = voice_samples[0].cuda() + # The diffusion model expects = 88200 conditioning samples. + if cond_diffusion.shape[-1] < 88200: + cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1])) + else: + cond_diffusion = cond_diffusion[:, :88200] + + diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free) + + with torch.no_grad(): + samples = [] + num_batches = num_autoregressive_samples // self.autoregressive_batch_size + stop_mel_token = self.autoregressive.stop_mel_token + self.autoregressive = self.autoregressive.cuda() + for b in tqdm(range(num_batches)): + codes = self.autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, + top_k=50, top_p=.95, + temperature=.9, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=1) + padding_needed = 250 - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + self.autoregressive = self.autoregressive.cpu() + + clip_results = [] + self.clip = self.clip.cuda() + for batch in samples: + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False)) + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + self.clip = self.clip.cpu() + del samples + + print("Performing vocoding..") + wav_candidates = [] + self.diffusion = self.diffusion.cuda() + self.vocoder = self.vocoder.cuda() + for b in range(best_results.shape[0]): + code = best_results[b].unsqueeze(0) + mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, mean=False) + wav = self.vocoder.inference(mel) + wav_candidates.append(wav.cpu()) + self.diffusion = self.diffusion.cpu() + self.vocoder = self.vocoder.cpu() + + if len(wav_candidates) > 1: + return wav_candidates + return wav_candidates[0] \ No newline at end of file diff --git a/do_tts.py b/do_tts.py index 8473fa2..aa2cbdc 100644 --- a/do_tts.py +++ b/do_tts.py @@ -138,8 +138,8 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol') - parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=1024) - parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=32) + parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512) + parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16) parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16) parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/') args = parser.parse_args() @@ -179,19 +179,15 @@ if __name__ == '__main__': del autoregressive print("Loading CLIP..") - clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8, - num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval() + clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, text_seq_len=350, text_heads=8, + num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, use_xformers=True).cuda().eval() clip.load_state_dict(torch.load('.models/clip.pth')) print("Performing CLIP filtering..") clip_results = [] for batch in samples: for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) - text = text[:, :120] # Ugly hack to fix the fact that I didn't train CLIP to handle long enough text. - clip_results.append(clip(text.repeat(batch.shape[0], 1), - torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'), - batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'), - return_loss=False)) + clip_results.append(clip(text.repeat(batch.shape[0], 1), batch, return_loss=False)) clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices] diff --git a/eval_multiple.py b/eval_multiple.py new file mode 100644 index 0000000..43e3b4a --- /dev/null +++ b/eval_multiple.py @@ -0,0 +1,33 @@ +import os + +import torchaudio + +from api import TextToSpeech +from utils.audio import load_audio + +if __name__ == '__main__': + fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv' + outpath = 'D:\\tmp\\tortoise-tts-eval\\baseline' + outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real' + + os.makedirs(outpath, exist_ok=True) + os.makedirs(outpath_real, exist_ok=True) + with open(fname, 'r', encoding='utf-8') as f: + lines = [l.strip().split('\t') for l in f.readlines()] + + recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8') + tts = TextToSpeech() + for e, line in enumerate(lines): + transcript = line[0] + if len(transcript) > 120: + continue # We need to support this, but cannot yet. + path = os.path.join(os.path.dirname(fname), line[1]) + cond_audio = load_audio(path, 22050) + torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050) + sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=512, k=1, diffusion_iterations=200, cond_free=True) + down = torchaudio.functional.resample(sample, 24000, 22050) + fout_path = os.path.join(outpath, os.path.basename(line[1])) + torchaudio.save(fout_path, down.squeeze(0), 22050) + recorder.write(f'{transcript}\t{fout_path}\n') + recorder.flush() + recorder.close() \ No newline at end of file diff --git a/models/arch_util.py b/models/arch_util.py index ea2c214..d374594 100644 --- a/models/arch_util.py +++ b/models/arch_util.py @@ -1,9 +1,11 @@ +import functools import math import torch import torch.nn as nn import torch.nn.functional as F import torchaudio +from x_transformers import ContinuousTransformerWrapper def zero_module(module): @@ -316,4 +318,46 @@ class TorchMelSpectrogram(nn.Module): if self.mel_norms is not None: self.mel_norms = self.mel_norms.to(mel.device) mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) - return mel \ No newline at end of file + return mel + + +class CheckpointedLayer(nn.Module): + """ + Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses + checkpoint for all other args. + """ + def __init__(self, wrap): + super().__init__() + self.wrap = wrap + + def forward(self, x, *args, **kwargs): + for k, v in kwargs.items(): + assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. + partial = functools.partial(self.wrap, **kwargs) + return torch.utils.checkpoint.checkpoint(partial, x, *args) + + +class CheckpointedXTransformerEncoder(nn.Module): + """ + Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid + to channels-last that XTransformer expects. + """ + def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs): + super().__init__() + self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) + self.needs_permute = needs_permute + self.exit_permute = exit_permute + + if not checkpoint: + return + for i in range(len(self.transformer.attn_layers.layers)): + n, b, r = self.transformer.attn_layers.layers[i] + self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + + def forward(self, x, **kwargs): + if self.needs_permute: + x = x.permute(0,2,1) + h = self.transformer(x, **kwargs) + if self.exit_permute: + h = h.permute(0,2,1) + return h \ No newline at end of file diff --git a/models/diffusion_decoder.py b/models/diffusion_decoder.py index 7a3bb4d..c57e9fb 100644 --- a/models/diffusion_decoder.py +++ b/models/diffusion_decoder.py @@ -15,7 +15,8 @@ from torch.nn import Linear from torch.utils.checkpoint import checkpoint from x_transformers import ContinuousTransformerWrapper, Encoder -from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock +from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \ + CheckpointedXTransformerEncoder def is_latent(t): @@ -157,43 +158,6 @@ class ResBlock(TimestepBlock): return self.skip_connection(x) + h -class CheckpointedLayer(nn.Module): - """ - Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses - checkpoint for all other args. - """ - def __init__(self, wrap): - super().__init__() - self.wrap = wrap - - def forward(self, x, *args, **kwargs): - for k, v in kwargs.items(): - assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. - partial = functools.partial(self.wrap, **kwargs) - return torch.utils.checkpoint.checkpoint(partial, x, *args) - - -class CheckpointedXTransformerEncoder(nn.Module): - """ - Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid - to channels-last that XTransformer expects. - """ - def __init__(self, needs_permute=True, **xtransformer_kwargs): - super().__init__() - self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) - self.needs_permute = needs_permute - - for i in range(len(self.transformer.attn_layers.layers)): - n, b, r = self.transformer.attn_layers.layers[i] - self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) - - def forward(self, x, **kwargs): - if self.needs_permute: - x = x.permute(0,2,1) - h = self.transformer(x, **kwargs) - return h.permute(0,2,1) - - class DiffusionTts(nn.Module): """ The full UNet model with attention and timestep embedding. diff --git a/models/text_voice_clip.py b/models/text_voice_clip.py index 31194ae..b4b51a7 100644 --- a/models/text_voice_clip.py +++ b/models/text_voice_clip.py @@ -2,6 +2,9 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import einsum +from x_transformers import Encoder + +from models.arch_util import CheckpointedXTransformerEncoder from models.transformer import Transformer @@ -13,7 +16,6 @@ def masked_mean(t, mask, dim = 1): t = t.masked_fill(~mask[:, :, None], 0.) return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] - class VoiceCLIP(nn.Module): """ CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding @@ -39,40 +41,69 @@ class VoiceCLIP(nn.Module): text_mask_percentage=0, voice_mask_percentage=0, wav_token_compression=1024, + use_xformers=False, ): super().__init__() self.text_emb = nn.Embedding(num_text_tokens, dim_text) - self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) - self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, - heads=text_heads) self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) - self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) - self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, - depth=speech_enc_depth, heads=speech_heads) self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + if use_xformers: + self.text_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=dim_text, + depth=text_enc_depth, + heads=text_heads, + ff_dropout=.1, + ff_mult=2, + attn_dropout=.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + )) + self.speech_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=dim_speech, + depth=speech_enc_depth, + heads=speech_heads, + ff_dropout=.1, + ff_mult=2, + attn_dropout=.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + )) + else: + self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, + heads=text_heads) + self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, + depth=speech_enc_depth, heads=speech_heads) + self.temperature = nn.Parameter(torch.tensor(1.)) self.text_mask_percentage = text_mask_percentage self.voice_mask_percentage = voice_mask_percentage self.wav_token_compression = wav_token_compression + self.xformers = use_xformers + if not use_xformers: + self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) + self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) def forward( self, text, - text_lengths, speech_tokens, - wav_lengths, return_loss=False ): - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text = text[:, :max_text_len] - max_mel_len = wav_lengths.max() // self.wav_token_compression - speech_tokens = speech_tokens[:, :max_mel_len] - b, device = text.shape[0], text.device if self.training: text_mask = torch.rand_like(text.float()) > self.text_mask_percentage @@ -82,10 +113,11 @@ class VoiceCLIP(nn.Module): voice_mask = torch.ones_like(speech_tokens.float()).bool() text_emb = self.text_emb(text) - text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) - speech_emb = self.speech_emb(speech_tokens) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + + if not self.xformers: + text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) + speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) enc_text = self.text_transformer(text_emb, mask=text_mask) enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)