From 91f28580e2f90e43d2e614b6b1677aecc4f389e7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 10 Jan 2022 16:17:31 -0700 Subject: [PATCH] fix unified_voice --- codes/models/gpt_voice/unified_voice2.py | 79 +++++++++++------------- codes/scripts/audio/gen/use_gpt_tts.py | 2 +- 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/codes/models/gpt_voice/unified_voice2.py b/codes/models/gpt_voice/unified_voice2.py index c307af40..726ef167 100644 --- a/codes/models/gpt_voice/unified_voice2.py +++ b/codes/models/gpt_voice/unified_voice2.py @@ -62,7 +62,7 @@ class MelEncoder(nn.Module): class UnifiedVoice(nn.Module): def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, - max_conditioning_length=60, shuffle_conditioning=True, mel_length_compression=1024, number_text_tokens=256, + mel_length_compression=1024, number_text_tokens=256, start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, checkpointing=True): @@ -74,8 +74,6 @@ class UnifiedVoice(nn.Module): max_text_tokens: Maximum number of text tokens that will be encountered by model. max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). - max_conditioning_length: Maximum length of conditioning input. Only needed if shuffle_conditioning=True - shuffle_conditioning: Whether or not the conditioning inputs will be shuffled across the sequence dimension. Useful if you want to provide the same input as conditioning and mel_codes. mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. number_text_tokens: start_text_token: @@ -95,7 +93,6 @@ class UnifiedVoice(nn.Module): self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token - self.shuffle_conditioning = shuffle_conditioning self.layers = layers self.heads = heads self.max_mel_tokens = max_mel_tokens @@ -110,7 +107,7 @@ class UnifiedVoice(nn.Module): else: self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ - build_hf_gpt_transformer(layers, model_dim, heads, self.max_text_tokens+2, self.max_mel_tokens+3, checkpointing) + build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) @@ -121,7 +118,6 @@ class UnifiedVoice(nn.Module): self.final_norm = nn.LayerNorm(model_dim) self.text_head = nn.Linear(model_dim, self.number_text_tokens) self.mel_head = nn.Linear(model_dim, self.number_mel_codes) - self.max_conditioning_length = max_conditioning_length # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding] @@ -149,23 +145,11 @@ class UnifiedVoice(nn.Module): mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens - def randomly_permute_conditioning_input(self, speech_conditioning_input): - """ - Randomly permute the conditioning spectrogram, to destroy any structure present. Note that since the - conditioning input is derived from a discrete spectrogram, it does actually retain structure, but only a little - bit (actually: exactly how much we want; enough to discriminate different vocal qualities, but nothing about - what is being said). - """ - cond_input = speech_conditioning_input[:,:,torch.randperm(speech_conditioning_input.shape[-1])] - if cond_input.shape[-1] > self.max_conditioning_length: - cond_input = cond_input[:,:,:self.max_conditioning_length] - return cond_input - - def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False): + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False): if second_inputs is not None: - emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1) + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) else: - emb = torch.cat([speech_conditioning_input, first_inputs], dim=1) + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) if get_attns: @@ -209,9 +193,11 @@ class UnifiedVoice(nn.Module): raw_mels = raw_mels[:, :, :max_mel_len*4] mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - if self.shuffle_conditioning: - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) @@ -223,9 +209,9 @@ class UnifiedVoice(nn.Module): mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) if text_first: - text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) else: - mel_logits, text_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions) + mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions) if return_attentions: return mel_logits @@ -245,13 +231,15 @@ class UnifiedVoice(nn.Module): max_text_len = text_lengths.max() text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) - if self.shuffle_conditioning: - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding - text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head) + text_logits = self.get_logits(conds, text_emb, self.text_head) loss_text = F.cross_entropy(text_logits, text_targets.long()) return loss_text.mean() @@ -269,9 +257,11 @@ class UnifiedVoice(nn.Module): if raw_mels is not None: raw_mels = raw_mels[:, :, :max_mel_len*4] - if self.shuffle_conditioning: - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) if raw_mels is not None: @@ -280,7 +270,7 @@ class UnifiedVoice(nn.Module): mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding - mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) + mel_logits = self.get_logits(conds, mel_emb, self.mel_head) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_mel.mean() @@ -302,12 +292,13 @@ class UnifiedVoice(nn.Module): text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - if self.shuffle_conditioning: - # Randomly permute the conditioning spectrogram, to destroy any structure present. - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) - cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) - emb = torch.cat([cond, text_emb], dim=1) + emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_mel_emb(emb) fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) @@ -324,10 +315,10 @@ def register_unified_voice2(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True) - l = gpt(torch.randn(2, 80, 800), - torch.randint(high=len(symbols), size=(2,80)), - torch.tensor([32, 80]), + gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4) + l = gpt(torch.randn(2, 3, 80, 800), + torch.randint(high=len(symbols), size=(2,120)), + torch.tensor([32, 120]), torch.randint(high=8192, size=(2,250)), - torch.tensor([150*256,195*256])) + torch.tensor([250*256,195*256])) gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py index a2c15c94..da3c9129 100644 --- a/codes/scripts/audio/gen/use_gpt_tts.py +++ b/codes/scripts/audio/gen/use_gpt_tts.py @@ -124,7 +124,7 @@ if __name__ == '__main__': text = F.pad(text, (0,1)) # This may not be necessary. cond_path = args.cond_path if args.cond_preset is None else preselected_cond_voices[args.cond_preset] - conds, cond_wav = load_conditioning(cond_path) + conds, cond_wav = load_conditioning(cond_path, cond_length=88000) with torch.no_grad(): print("Performing GPT inference..")