diff --git a/codes/models/gpt_voice/transformer_builders.py b/codes/models/gpt_voice/transformer_builders.py index 5092fb42..ae10d714 100644 --- a/codes/models/gpt_voice/transformer_builders.py +++ b/codes/models/gpt_voice/transformer_builders.py @@ -7,13 +7,23 @@ Every function contains the following arguments: layers: Net number of layers in the transformer. model_dim: Hidden dimensionality of the model. heads: Number of attention heads. - num_tokens: Number of possible tokens in the transformer's dictionary. Do not use this in future releases. - max_seq_len: Maximum sequence length to attend to. + max_mel_seq_len: Maximum mel sequence length to attend to. + max_text_seq_len: Maximum text sequence length to attend to. checkpointing: Whether or not the underlying implementation should support gradient checkpointing. + +Returns: + (model, global_mel_pos_embedding, global_text_pos_embedding, local_mel_pos_embedding, local_text_pos_embedding) + model: The transformer model + global_mel_pos_embedding: A global embedding function (that takes the MEL sequence as input) which should be added on to the MEL embeddings. + global_text_pos_embedding: The global embedding function for text tokens. + local_mel_pos_embedding: A local embedding function which, if not None, should be concatenated with the local text position embeddings and fed to the transformer. + local_text_pos_embedding: The local embedding function for text positions which will be None if local_mel_pos_embedding=None. + """ import functools from time import time import torch +import torch.nn as nn from tqdm import tqdm @@ -21,46 +31,58 @@ def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) -def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing): +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=.02): + super().__init__() + self.emb = nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + + def forward(self, x): + sl = x.shape[1] + return self.emb(torch.arange(0, sl, device=x.device)) + + +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): """ GPT-2 implemented by the HuggingFace library. """ from transformers import GPT2Config, GPT2Model - gpt_config = GPT2Config(vocab_size=num_tokens, - n_positions=max_seq_len, - n_ctx=max_seq_len, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) + gpt_config = GPT2Config(vocab_size=256, # Unused. + n_positions=max_mel_seq_len+max_text_seq_len, + n_ctx=max_mel_seq_len+max_text_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) gpt = GPT2Model(gpt_config) # Override the built in positional embeddings del gpt.wpe gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Built-in token embeddings are unused. del gpt.wte - return gpt + return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\ + None, None -def build_lr_performer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing): +def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): """ lucidrains Performer implementation, https://github.com/lucidrains/performer-pytorch """ - from models.lucidrains.performer.performer_pytorch import PerformerLM - model = PerformerLM(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True, - num_tokens=num_tokens, max_seq_len=max_seq_len) + from models.lucidrains.performer.performer_pytorch import Performer + model = Performer(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True) return model -def build_lr_reformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing): +def build_lr_reformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): """ lucidrains Reformer implementation, https://github.com/lucidrains/reformer-pytorch """ pass -def build_lr_xformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing): +def build_lr_xformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): """ lucidrains x-transformer implementation, https://github.com/lucidrains/x-transformers """ diff --git a/codes/models/gpt_voice/unified_voice2.py b/codes/models/gpt_voice/unified_voice2.py index 239168b9..c307af40 100644 --- a/codes/models/gpt_voice/unified_voice2.py +++ b/codes/models/gpt_voice/unified_voice2.py @@ -105,14 +105,12 @@ class UnifiedVoice(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) - self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim) if use_mel_codes_as_input: self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) else: self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) - self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim) - self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs - self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing) + self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ + build_hf_gpt_transformer(layers, model_dim, heads, self.max_text_tokens+2, self.max_mel_tokens+3, checkpointing) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) @@ -126,13 +124,11 @@ class UnifiedVoice(nn.Module): self.max_conditioning_length = max_conditioning_length # Initialize the embeddings per the GPT-2 scheme - embeddings = [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding] + embeddings = [self.text_embedding] if use_mel_codes_as_input: embeddings.append(self.mel_embedding) - for module in: + for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() def build_aligned_inputs_and_targets(self, input, start_token, stop_token): inp = F.pad(input, (1,0), value=start_token) @@ -218,14 +214,14 @@ class UnifiedVoice(nn.Module): speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) if raw_mels is not None: mel_inp = F.pad(raw_mels, (0, 8)) else: mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) if text_first: text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) else: @@ -254,7 +250,7 @@ class UnifiedVoice(nn.Module): speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + self.text_solo_embedding + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head) loss_text = F.cross_entropy(text_logits, text_targets.long()) return loss_text.mean() @@ -283,7 +279,7 @@ class UnifiedVoice(nn.Module): else: mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_mel.mean() @@ -291,9 +287,10 @@ class UnifiedVoice(nn.Module): def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): if not hasattr(self, 'inference_model'): # TODO: Decouple gpt_config from this inference model. + seq_length = self.max_mel_tokens + self.max_text_tokens + 5 gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, - n_positions=self.seq_length, - n_ctx=self.seq_length, + n_positions=seq_length, + n_ctx=seq_length, n_embd=self.model_dim, n_layer=self.layers, n_head=self.heads, @@ -303,7 +300,7 @@ class UnifiedVoice(nn.Module): text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) if self.shuffle_conditioning: # Randomly permute the conditioning spectrogram, to destroy any structure present.