forked from mrq/DL-Art-School
unified_voice2: decouple positional embeddings and token embeddings from underlying gpt model
This commit is contained in:
parent f503d8d96b
commit ee3dfac2ae
@@ -7,13 +7,23 @@ Every function contains the following arguments:
     layers: Net number of layers in the transformer.
     model_dim: Hidden dimensionality of the model.
     heads: Number of attention heads.
-    num_tokens: Number of possible tokens in the transformer's dictionary. Do not use this in future releases.
-    max_seq_len: Maximum sequence length to attend to.
+    max_mel_seq_len: Maximum mel sequence length to attend to.
+    max_text_seq_len: Maximum text sequence length to attend to.
     checkpointing: Whether or not the underlying implementation should support gradient checkpointing.
+
+Returns:
+    (model, global_mel_pos_embedding, global_text_pos_embedding, local_mel_pos_embedding, local_text_pos_embedding)
+    model: The transformer model
+    global_mel_pos_embedding: A global embedding function (that takes the MEL sequence as input) which should be added on to the MEL embeddings.
+    global_text_pos_embedding: The global embedding function for text tokens.
+    local_mel_pos_embedding: A local embedding function which, if not None, should be concatenated with the local text position embeddings and fed to the transformer.
+    local_text_pos_embedding: The local embedding function for text positions which will be None if local_mel_pos_embedding=None.
+
 """
 import functools
 from time import time
 import torch
+import torch.nn as nn
 from tqdm import tqdm


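The docstring above defines a five-element return contract. As a rough usage sketch (hypothetical sizes, and assuming build_hf_gpt_transformer from this file is importable), a caller consumes it like this: global position embeddings get added onto the token embeddings, while the local ones (always None for the GPT-2 backend) would be concatenated instead.

    import torch

    # Hypothetical dimensions, for illustration only.
    gpt, mel_pos, text_pos, local_mel_pos, local_text_pos = build_hf_gpt_transformer(
        layers=2, model_dim=64, heads=2,
        max_mel_seq_len=32, max_text_seq_len=16, checkpointing=False)

    # Global embeddings are *added* to the (fake) token embeddings.
    mel_emb = torch.randn(1, 10, 64) + mel_pos(torch.randint(0, 8192, (1, 10)))
    text_emb = torch.randn(1, 6, 64) + text_pos(torch.randint(0, 256, (1, 6)))
    assert local_mel_pos is None and local_text_pos is None  # no local embeddings for GPT-2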
@@ -21,46 +31,58 @@ def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


-def build_hf_gpt_transformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+class LearnedPositionEmbeddings(nn.Module):
+    def __init__(self, seq_len, model_dim, init=.02):
+        super().__init__()
+        self.emb = nn.Embedding(seq_len, model_dim)
+        # Initializing this way is standard for GPT-2
+        self.emb.weight.data.normal_(mean=0.0, std=init)
+
+    def forward(self, x):
+        sl = x.shape[1]
+        return self.emb(torch.arange(0, sl, device=x.device))
+
+
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     GPT-2 implemented by the HuggingFace library.
     """
     from transformers import GPT2Config, GPT2Model
-    gpt_config = GPT2Config(vocab_size=num_tokens,
-                            n_positions=max_seq_len,
-                            n_ctx=max_seq_len,
+    gpt_config = GPT2Config(vocab_size=256, # Unused.
+                            n_positions=max_mel_seq_len+max_text_seq_len,
+                            n_ctx=max_mel_seq_len+max_text_seq_len,
                             n_embd=model_dim,
                             n_layer=layers,
                             n_head=heads,
                             gradient_checkpointing=checkpointing,
                             use_cache=not checkpointing)
     gpt = GPT2Model(gpt_config)
     # Override the built in positional embeddings
     del gpt.wpe
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return gpt
+    return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
+           None, None


-def build_lr_performer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     lucidrains Performer implementation, https://github.com/lucidrains/performer-pytorch
     """
-    from models.lucidrains.performer.performer_pytorch import PerformerLM
-    model = PerformerLM(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True,
-                        num_tokens=num_tokens, max_seq_len=max_seq_len)
+    from models.lucidrains.performer.performer_pytorch import Performer
+    model = Performer(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True)
     return model


-def build_lr_reformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+def build_lr_reformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     lucidrains Reformer implementation, https://github.com/lucidrains/reformer-pytorch
     """
     pass


-def build_lr_xformer(layers, model_dim, heads, num_tokens, max_seq_len, checkpointing):
+def build_lr_xformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     lucidrains x-transformer implementation, https://github.com/lucidrains/x-transformers
     """

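For reference, the mechanism this hunk builds on: null_position_embeddings stubs out GPT-2's internal wpe so that all position information comes from the externally owned LearnedPositionEmbeddings. A self-contained sketch of that override, with tiny made-up dimensions (requires the transformers package):

    import functools
    import torch
    from transformers import GPT2Config, GPT2Model

    def null_position_embeddings(range, dim):
        # Stand-in for gpt.wpe: contributes zeros, i.e. no internal position signal.
        return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)

    gpt = GPT2Model(GPT2Config(vocab_size=256, n_positions=48, n_ctx=48,
                               n_embd=64, n_layer=2, n_head=2))
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=64)

    # With wpe nulled and wte bypassed, the model is driven entirely through
    # inputs_embeds, which must already carry token + position information.
    emb = torch.randn(1, 10, 64)
    print(gpt(inputs_embeds=emb).last_hidden_state.shape)  # torch.Size([1, 10, 64])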
@@ -105,14 +105,12 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
-        self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
-        self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
+        self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_text_tokens+2, self.max_mel_tokens+3, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
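With the builder owning the position tables, the constructor reduces to one unpacking assignment. A minimal sketch of the new wiring pattern (hypothetical class and sizes, assuming build_hf_gpt_transformer above is importable):

    import torch.nn as nn

    class TinyVoice(nn.Module):
        # Illustrative only; mirrors the unpacking in UnifiedVoice.__init__.
        def __init__(self, max_text_tokens=16, max_mel_tokens=32, model_dim=64):
            super().__init__()
            self.gpt, self.mel_pos_embedding, self.text_pos_embedding, \
                self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
                build_hf_gpt_transformer(2, model_dim, 2, max_text_tokens+2,
                                         max_mel_tokens+3, checkpointing=False)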
@@ -126,13 +124,11 @@ class UnifiedVoice(nn.Module):
         self.max_conditioning_length = max_conditioning_length

         # Initialize the embeddings per the GPT-2 scheme
-        embeddings = [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]
+        embeddings = [self.text_embedding]
         if use_mel_codes_as_input:
             embeddings.append(self.mel_embedding)
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()

     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1,0), value=start_token)
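The two position embeddings drop out of this init list because LearnedPositionEmbeddings already applies the same N(0, 0.02) initialization in its own __init__. The scheme itself, in isolation:

    import torch.nn as nn

    emb = nn.Embedding(256, 64)
    emb.weight.data.normal_(mean=0.0, std=.02)  # GPT-2 scheme: zero-mean normal, std 0.02
    print(emb.weight.std())                     # roughly 0.02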
@@ -218,14 +214,14 @@ class UnifiedVoice(nn.Module):
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
         if raw_mels is not None:
             mel_inp = F.pad(raw_mels, (0, 8))
         else:
             mel_inp = mel_codes
         mel_emb = self.mel_embedding(mel_inp)
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
         if text_first:
             text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
         else:
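The changed call sites are equivalent in effect: LearnedPositionEmbeddings.forward derives the arange from x.shape[1] internally, so passing the token tensor produces exactly what the old explicit torch.arange did. A quick standalone check (class copied from the hunk above):

    import torch
    import torch.nn as nn

    class LearnedPositionEmbeddings(nn.Module):
        def __init__(self, seq_len, model_dim, init=.02):
            super().__init__()
            self.emb = nn.Embedding(seq_len, model_dim)
            self.emb.weight.data.normal_(mean=0.0, std=init)

        def forward(self, x):
            return self.emb(torch.arange(0, x.shape[1], device=x.device))

    pos = LearnedPositionEmbeddings(32, 8)
    text_inputs = torch.randint(0, 100, (4, 10))
    old_style = pos.emb(torch.arange(text_inputs.shape[1]))  # pre-commit call pattern
    new_style = pos(text_inputs)                             # post-commit call pattern
    assert torch.equal(old_style, new_style)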
@@ -254,7 +250,7 @@ class UnifiedVoice(nn.Module):
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + self.text_solo_embedding
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
         text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean()
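One note on the loss: F.cross_entropy with multi-dimensional logits expects the class dimension second, i.e. (batch, classes, time) logits against (batch, time) integer targets, which is presumably the layout get_logits produces here. A standalone sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 256, 10)          # (batch, num_classes, time)
    targets = torch.randint(0, 256, (4, 10))  # (batch, time)
    loss = F.cross_entropy(logits, targets.long())
    print(loss)  # already a scalar mean, so the trailing .mean() is a no-op safeguard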
@@ -283,7 +279,7 @@ class UnifiedVoice(nn.Module):
         else:
             mel_inp = mel_codes
         mel_emb = self.mel_embedding(mel_inp)
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
+        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
         mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
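The solo embedding is a (1, 1, model_dim) parameter, so the addition broadcasts one learned vector across every batch item and timestep:

    import torch

    model_dim = 64
    mel_solo_embedding = torch.randn(1, 1, model_dim) * .02  # shaped as in the constructor
    mel_emb = torch.randn(4, 10, model_dim)
    print((mel_emb + mel_solo_embedding).shape)  # torch.Size([4, 10, 64])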
@@ -291,9 +287,10 @@ class UnifiedVoice(nn.Module):
     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
+            seq_length = self.max_mel_tokens + self.max_text_tokens + 5
             gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-                                    n_positions=self.seq_length,
-                                    n_ctx=self.seq_length,
+                                    n_positions=seq_length,
+                                    n_ctx=seq_length,
                                     n_embd=self.model_dim,
                                     n_layer=self.layers,
                                     n_head=self.heads,
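The inference path now derives its context budget locally rather than reading the deleted self.seq_length. The +5 lines up with the +2/+3 padding passed to build_hf_gpt_transformer in the constructor; with hypothetical budgets:

    max_mel_tokens, max_text_tokens = 604, 402  # illustrative values only
    seq_length = max_mel_tokens + max_text_tokens + 5
    # Same total as the constructor's (max_text_tokens+2) + (max_mel_tokens+3):
    assert seq_length == (max_text_tokens + 2) + (max_mel_tokens + 3)
    print(seq_length)  # 1011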
@@ -303,7 +300,7 @@ class UnifiedVoice(nn.Module):

         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)

         if self.shuffle_conditioning:
             # Randomly permute the conditioning spectrogram, to destroy any structure present.