From 3214ca0dfef49f96c4fc797f19ae1219b42da863 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 12 Apr 2022 20:53:09 -0600 Subject: [PATCH] support latents into the diffusion decoder --- api.py | 21 ++- eval_multiple.py | 2 +- models/autoregressive.py | 44 ++++-- models/diffusion_decoder.py | 17 ++- models/new_autoregressive.py | 286 ----------------------------------- 5 files changed, 55 insertions(+), 315 deletions(-) delete mode 100644 models/new_autoregressive.py diff --git a/api.py b/api.py index f5b2cd6..204c91f 100644 --- a/api.py +++ b/api.py @@ -117,7 +117,7 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_ cond_mels.append(cond_mel) cond_mels = torch.stack(cond_mels, dim=1) - output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_seq_len = mel_codes.shape[1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. output_shape = (mel_codes.shape[0], 100, output_seq_len) precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False) @@ -151,11 +151,6 @@ class TextToSpeech: layer_drop=0, unconditioned_percentage=0).cpu().eval() self.diffusion.load_state_dict(torch.load('.models/diffusion.pth')) - self.diffusion_next = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, - in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, - layer_drop=0, unconditioned_percentage=0).cpu().eval() - self.diffusion_next.load_state_dict(torch.load('.models/diffusion_next.pth')) - self.vocoder = UnivNetGenerator().cpu() self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) self.vocoder.eval(inference=True) @@ -223,12 +218,22 @@ class TextToSpeech: self.clip = self.clip.cpu() del samples + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. + self.autoregressive = self.autoregressive.cuda() + best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results, + torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device), + return_latent=True, clip_inputs=False) + self.autoregressive = self.autoregressive.cpu() + print("Performing vocoding..") wav_candidates = [] self.diffusion = self.diffusion.cuda() self.vocoder = self.vocoder.cuda() for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) + latents = best_latents[b].unsqueeze(0) # Find the first occurrence of the "calm" token and trim the codes to that. ctokens = 0 @@ -238,10 +243,10 @@ class TextToSpeech: else: ctokens = 0 if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. - codes = codes[:, :k] + latents = latents[:, :k] break - mel = do_spectrogram_diffusion(self.diffusion, diffuser, codes, voice_samples, temperature=diffusion_temperature) + mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature) wav = self.vocoder.inference(mel) wav_candidates.append(wav.cpu()) self.diffusion = self.diffusion.cpu() diff --git a/eval_multiple.py b/eval_multiple.py index c55cdc1..9f1919d 100644 --- a/eval_multiple.py +++ b/eval_multiple.py @@ -7,7 +7,7 @@ from utils.audio import load_audio if __name__ == '__main__': fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv' - outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_auto_256_samp_100_di_4' + outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_new_decoder_1' outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real' os.makedirs(outpath, exist_ok=True) diff --git a/models/autoregressive.py b/models/autoregressive.py index 6f40ca7..64fd451 100644 --- a/models/autoregressive.py +++ b/models/autoregressive.py @@ -362,7 +362,7 @@ class UnifiedVoice(nn.Module): mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens - def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False): + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): if second_inputs is not None: emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) else: @@ -374,6 +374,10 @@ class UnifiedVoice(nn.Module): enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input enc = self.final_norm(enc) + + if return_latent: + return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] + first_logits = enc[:, :first_inputs.shape[1]] first_logits = first_head(first_logits) first_logits = first_logits.permute(0,2,1) @@ -385,7 +389,8 @@ class UnifiedVoice(nn.Module): else: return first_logits - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False): + def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False, + return_latent=False, clip_inputs=True): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode (actuated by `text_first`). @@ -396,19 +401,23 @@ class UnifiedVoice(nn.Module): mel_inputs: long tensor, (b,m) wav_lengths: long tensor, (b,) raw_mels: MEL float tensor (b,80,s) - """ - assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' - assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) - max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) - if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, :max_mel_len*4] mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] @@ -427,10 +436,15 @@ class UnifiedVoice(nn.Module): mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + if text_first: - text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. else: - mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions) + mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. if return_attentions: return mel_logits diff --git a/models/diffusion_decoder.py b/models/diffusion_decoder.py index 1baf809..5fdf7ad 100644 --- a/models/diffusion_decoder.py +++ b/models/diffusion_decoder.py @@ -176,7 +176,13 @@ class DiffusionTts(nn.Module): AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), ) self.code_norm = normalization(model_channels) - self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1) + self.latent_conditioner = nn.Sequential( + nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), @@ -190,6 +196,7 @@ class DiffusionTts(nn.Module): DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads), ) + self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) @@ -206,7 +213,7 @@ class DiffusionTts(nn.Module): groups = { 'minicoder': list(self.contextual_embedder.parameters()), 'layers': list(self.layers.parameters()), - 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()), + 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()), 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), 'time_embed': list(self.time_embed.parameters()), } @@ -227,7 +234,7 @@ class DiffusionTts(nn.Module): cond_emb = conds.mean(dim=-1) cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) if is_latent(aligned_conditioning): - code_emb = self.autoregressive_latent_converter(aligned_conditioning) + code_emb = self.latent_conditioner(aligned_conditioning) else: code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) code_emb = self.code_converter(code_emb) @@ -269,7 +276,7 @@ class DiffusionTts(nn.Module): if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) - unused_params.extend(list(self.latent_converter.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings @@ -278,7 +285,7 @@ class DiffusionTts(nn.Module): if is_latent(aligned_conditioning): unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: - unused_params.extend(list(self.latent_converter.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) unused_params.append(self.unconditioned_embedding) diff --git a/models/new_autoregressive.py b/models/new_autoregressive.py deleted file mode 100644 index aba8c11..0000000 --- a/models/new_autoregressive.py +++ /dev/null @@ -1,286 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2PreTrainedModel, GPT2Config -from models.xtransformers import TransformerWrapper, Encoder, Decoder -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions - -from models.arch_util import AttentionBlock - - -class InferenceModel(GPT2PreTrainedModel): - """ - Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with - this transformer. - """ - def __init__(self, model): - super().__init__(GPT2Config()) - self.transformer = model - self.context = None - - def parallelize(self, device_map=None): - # Not implemented. - pass - - def deparallelize(self): - # Not implemented. - pass - - def get_output_embeddings(self): - assert False, "Unsupported operation." - - def set_output_embeddings(self, new_embeddings): - assert False, "Unsupported operation." - - def store_context(self, context): - self.context = context - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - assert self.context is not None - assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, - use_cache=use_cache, expected_seq_len=100) - if use_cache: - hidden_states, present_key_values = out - else: - hidden_states = out - present_key_values = None - logits = self.transformer.decoder.to_logits(hidden_states) - - if not return_dict: - return (logits, ) - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=logits, - past_key_values=present_key_values, - hidden_states=hidden_states, - attentions=None, - cross_attentions=None, - ) - - @staticmethod - def _reorder_cache(past, beam_idx): - """ - This function is used to re-order the :obj:`past_key_values` cache if - :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is - called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -class ResBlock(nn.Module): - """ - Basic residual convolutional block that uses GroupNorm. - """ - def __init__(self, chan): - super().__init__() - self.net = nn.Sequential( - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan), - nn.ReLU(), - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan//8, chan) - ) - - def forward(self, x): - return F.relu(self.net(x) + x) - - -class ConditioningEncoder(nn.Module): - def __init__(self, - spec_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=4, - do_checkpointing=False): - super().__init__() - attn = [] - self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2), - nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2), - ResBlock(embedding_dim//2), - nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2)) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - - def forward(self, x): - h = self.init(x) - h = self.attn(h) - return h.mean(dim=2) - - -class AutoregressiveCodegen(nn.Module): - def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1): - super().__init__() - assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later. - - self.START_TOKEN=8192 - self.STOP_TOKEN=8193 - self.START_TEXT_TOKEN = 255 - self.STOP_TEXT_TOKEN = 0 - self.max_text_token_id = num_text_tokens - self.max_mel_token_id = num_mel_tokens - self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False) - self.encoder = TransformerWrapper( - num_tokens=num_text_tokens, - use_pos_emb=False, - max_seq_len=-1, - attn_layers = Encoder( - depth=depth, - heads=model_dim//64, - dim=model_dim, - attn_dropout=dropout, - ff_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - ff_mult=1, - rotary_pos_emb=True, - attn_rel_pos_bias=True, - )) - self.encoder.norm = nn.Identity() # This layer and the next are unused. - self.encoder.to_logits = nn.Identity() - self.decoder = TransformerWrapper( - num_tokens=num_mel_tokens, - use_pos_emb=False, - max_seq_len=-1, - attn_layers=Decoder( - depth=depth, - heads=model_dim//64, - dim=model_dim, - attn_dropout=dropout, - ff_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - ff_mult=1, - rotary_pos_emb=True, - cross_attend=True, - attn_rel_pos_bias=True, - )) - - def get_grad_norm_parameter_groups(self): - return { - 'encoder': list(self.encoder.parameters()), - 'decoder': list(self.decoder.parameters()), - 'minicoder': list(self.mel_embedding.parameters()), - } - - def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True): - assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}' - assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}' - - # Format mel_codes with a stop token on the end. - mel_lengths = wav_lengths // 1024 + 1 - for b in range(mel_codes.shape[0]): - mel_codes[b, mel_lengths[b]:] = self.STOP_TOKEN - mel_codes = F.pad(mel_codes, (0, 1), value=self.STOP_TOKEN) - - # Build the context - if len(conditioning_signal.shape) != 4: - conditioning_signal = conditioning_signal.unsqueeze(1) - cond_embs = [] - for i in range(conditioning_signal.shape[1]): - cond_embs.append(self.mel_embedding(conditioning_signal[:, i])) - cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True) - # Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings. - text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN) - text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN) - _, enc_text = self.encoder(text_codes, return_hiddens=True) - # Interleave cond_emb into the first few contexts. - full_context = enc_text - full_context[1] = cond_emb - full_context[3] = cond_emb - full_context[6] = cond_emb - - # Execute the decoder - dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1] - dec = self.decoder(dec_inputs, full_context=full_context) - if not return_loss: - return dec - loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes) - return loss_mel - - def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs): - inference_model = InferenceModel(self) - # Build the context - if len(conditioning_signal.shape) != 4: - conditioning_signal = conditioning_signal.unsqueeze(1) - cond_embs = [] - for i in range(conditioning_signal.shape[1]): - cond_embs.append(self.mel_embedding(conditioning_signal[:, i])) - cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True) - text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN) - text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN) - _, enc_text = self.encoder(text_codes, return_hiddens=True) - # Interleave cond_emb into the first few contexts. - full_context = enc_text - full_context[1] = cond_emb - full_context[3] = cond_emb - full_context[6] = cond_emb - inference_model.store_context(full_context) - - gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN, - max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=False, - **hf_generate_kwargs) - return gen.sequences - - -if __name__ == '__main__': - codegen = AutoregressiveCodegen(256, 10) - torch.save(codegen.state_dict(), 'sample.pth') - #codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200))) - codegen(torch.randint(0,256, (2,200)), - torch.randn(2,80,120), - torch.randint(0,8192, (2,350)), - torch.tensor([192,350]))