From 48e3ee9a5bfa59f6eac880c3cab0869210bec593 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 20 Dec 2021 19:05:56 -0700 Subject: [PATCH] Shuffle conditioning inputs along the positional axis to reduce fitting on prosody and other positional information The mels should still retain some short-range positional information the model can use for tone and frequencies, for example. --- codes/models/gpt_voice/gpt_tts_hf.py | 58 +++++++++++++++++----------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index 31da0343..b62d5a57 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -8,6 +8,7 @@ from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedM from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.utils.model_parallel_utils import get_device_map, assert_device_map +from models.arch_util import AttentionBlock from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel from models.gpt_voice.mini_encoder import AudioMiniEncoder from models.tacotron2.text import symbols @@ -15,6 +16,28 @@ from trainer.networks import register_model from utils.util import opt_get +class ConditioningEncoder(nn.Module): + def __init__(self, + spec_dim, + embedding_dim, + attn_blocks=4, + num_attn_heads=4, + do_checkpointing=False): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + + def forward(self, x): + h = self.init(x) + h = self.attn(h) + return h[:, :, 0] + + class GptTtsHf(nn.Module): NUMBER_TEXT_TOKENS = len(symbols)+1 START_TEXT_TOKEN = len(symbols) @@ -24,14 +47,14 @@ class GptTtsHf(nn.Module): STOP_MEL_TOKEN = 8193 def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3, - checkpointing=True, mel_length_compression=1024, max_conditioning_length=44100//256): + checkpointing=True, mel_length_compression=1024, max_conditioning_length=60): super().__init__() self.max_mel_tokens = max_mel_tokens self.max_symbols_per_phrase = max_symbols_per_phrase self.model_dim = model_dim self.max_conditioning_inputs = max_conditioning_inputs self.mel_length_compression = mel_length_compression - self.conditioning_encoder = AudioMiniEncoder(80, model_dim) + self.conditioning_encoder = ConditioningEncoder(80, model_dim) self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES, @@ -54,19 +77,12 @@ class GptTtsHf(nn.Module): tar = F.pad(input, (0,1), value=stop_token) return inp, tar - def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False): + def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False): text_emb = self.text_embedding(text_inputs) - - conds = [] - for k in range(cond_inputs.shape[1]): - conds.append(self.conditioning_encoder(cond_inputs[:, k])) - while len(conds) < self.max_conditioning_inputs: - conds.append(conds[-1]) - conds = torch.stack(conds, dim=1) - + cond = self.conditioning_encoder(cond_input).unsqueeze(1) mel_emb = self.gpt.get_input_embeddings()(mel_inputs) - emb = torch.cat([text_emb, conds, mel_emb], dim=1) + emb = torch.cat([text_emb, cond, mel_emb], dim=1) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) if get_attns: return gpt_out.attentions @@ -81,7 +97,7 @@ class GptTtsHf(nn.Module): return text_logits, mel_logits - def forward(self, text_inputs, cond_inputs, mel_targets, wav_lengths, return_attentions=False): + def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False): """ Forward pass text_inputs: long tensor, (b,t) @@ -95,18 +111,14 @@ class GptTtsHf(nn.Module): if mel_lengths[b] < mel_targets.shape[-1]: mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN - # Format conditioning inputs properly. - if len(cond_inputs.shape) == 3: - cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1} - if cond_inputs.shape[-1] > self.max_conditioning_length: - # Remember, that this doesn't necessarily mean that the conditioning inputs aren't mostly zero-padded, so - # skew trimming towards the front end of the clip. - rand_clip = random.randint(0, min(50, cond_inputs.shape[-1]-self.max_conditioning_length)) - cond_inputs = cond_inputs[:,:,:,rand_clip:rand_clip+self.max_conditioning_length] + # Randomly permute the conditioning spectrogram, to destroy any structure present. + cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])] + if cond_input.shape[-1] > self.max_conditioning_length: + cond_input = cond_input[:,:,:self.max_conditioning_length] text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN) - text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_inputs, get_attns=return_attentions) + text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions) if return_attentions: return mel_logits loss_text = F.cross_entropy(text_logits, text_targets.long()) @@ -153,6 +165,6 @@ def register_gpt_tts_hf(opt_net, opt): if __name__ == '__main__': gpt = GptTtsHf(model_dim=1024, heads=16) l = gpt(torch.randint(high=len(symbols), size=(2,200)), - torch.randn(2,80,800), + torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800), torch.randint(high=8192, size=(2,250)), torch.tensor([150*256,195*256]))