forked from mrq/DL-Art-School
unified_voice with rotary embeddings
This commit is contained in:
parent 573e5552b9
commit 3f8d7955ef
@@ -59,8 +59,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     """
     from transformers import GPT2Config, GPT2Model
     gpt_config = GPT2Config(vocab_size=256, # Unused.
-                            n_positions=max_mel_seq_len+max_text_seq_len,
-                            n_ctx=max_mel_seq_len+max_text_seq_len,
+                            n_positions=1,
+                            n_ctx=1,
                             n_embd=model_dim,
                             n_layer=layers,
                             n_head=heads,
@@ -72,8 +72,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
-    return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
-           None, None
+    mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
+    text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
+    return gpt, mel_pos_emb, text_pos_emb, None, None


 def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
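Note on the new -1 code path above (not part of the diff): when a sequence length is passed as -1, the builder hands back the same zero-valued position-embedding stand-in that already replaces GPT-2's wpe, so all positional information has to come from the rotary embeddings added further down. A minimal sketch of what that stand-in presumably looks like; the body of the repo's existing null_position_embeddings helper is an assumption here, written out only for illustration:

import torch

def null_position_embeddings(range, dim):
    # Assumed behavior: an all-zero "position embedding" matching the
    # (batch, sequence) shape of the input ids, i.e. absolute positional
    # information is intentionally absent.
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)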
@@ -3,10 +3,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GPT2Config, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map

 from models.arch_util import AttentionBlock
 from models.audio.tts.transformer_builders import build_hf_gpt_transformer
+from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
 from trainer.networks import register_model
 from utils.util import opt_get

@@ -183,6 +185,73 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )


+
+class GPT2AttentionWithRotaryEmbeddings(GPT2Attention):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+        super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
+        self.rotary_pos_emb = RotaryEmbedding(32)
+
+    def forward(
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                )
+
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        # Apply rotary embeddings. This is the only difference between this implementation and the HF one.
+        rotary_pos_emb = self.rotary_pos_emb(hidden_states.shape[1], hidden_states.device)
+        l = rotary_pos_emb.shape[-1]
+        (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (query, key, value))
+        ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
+        query, key, value = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+
+        if self.reorder_and_upcast_attn:
+            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+        else:
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)
+
+
 class ConditioningEncoder(nn.Module):
     def __init__(self,
                  spec_dim,
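For intuition, here is a self-contained sketch (not part of the diff) of the partial-rotary step used in the forward() above: only the first l channels of each head are rotated, and the remaining channels pass through unchanged. The helper names mirror lucidrains' x_transformers, but the bodies below are simplified assumptions rather than that library's exact code:

import torch

def rotate_half(x):
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(t, freqs):
    # Standard RoPE mixing of the original vector with its rotated half.
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())

def rotary_freqs(seq_len, dim, device=None):
    # Sinusoid frequencies, duplicated so cos/sin cover all `dim` channels.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    return torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)

# Toy shapes: batch=2, heads=4, seq=10, head_dim=64, rotary width l=32.
q = torch.randn(2, 4, 10, 64)
freqs = rotary_freqs(seq_len=10, dim=32)
l = freqs.shape[-1]
ql, qr = q[..., :l], q[..., l:]   # rotate only the first l channels
ql = apply_rotary_pos_emb(ql, freqs)
q = torch.cat((ql, qr), dim=-1)   # recombine; overall shape is unchanged
print(q.shape)                    # torch.Size([2, 4, 10, 64])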
@@ -239,7 +308,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, use_rotary_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -270,8 +339,8 @@ class UnifiedVoice(nn.Module):
         self.stop_mel_token = stop_mel_token
         self.layers = layers
         self.heads = heads
-        self.max_mel_tokens = max_mel_tokens
-        self.max_text_tokens = max_text_tokens
+        self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
+        self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
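The -1 sentinel introduced here means "unbounded"; otherwise the stored budgets now fold in the start/stop tokens (and conditioning slots on the mel side) so the transformer builder no longer has to. A quick worked example (not part of the diff) with hypothetical values, not the model's defaults:

# Hypothetical values, for illustration only.
max_mel_tokens, max_text_tokens, max_conditioning_inputs = 604, 402, 2

mel_budget = max_mel_tokens + 2 + max_conditioning_inputs  # 608: mel codes + start/stop + conditioning slots
text_budget = max_text_tokens + 2                          # 404: text tokens + start/stop
print(mel_budget, text_budget)                             # 608 404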
@@ -283,7 +352,7 @@ class UnifiedVoice(nn.Module):
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
+            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
         if train_solo_embeddings:
             self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
             self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -291,6 +360,11 @@ class UnifiedVoice(nn.Module):
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0

+        if use_rotary_embeddings:
+            # We must re-build all the attention layers as type GPT2AttentionWithRotaryEmbeddings.
+            for blk in self.gpt.h:
+                blk.attn = GPT2AttentionWithRotaryEmbeddings(self.gpt.config, layer_idx=blk.attn.layer_idx)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
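A minimal sketch (not part of the diff) of the rebuild pattern above on a plain HuggingFace GPT-2, assuming only that GPT2Model keeps its transformer blocks in .h and each block's attention module in .attn; the subclass here is a stand-in for GPT2AttentionWithRotaryEmbeddings:

from transformers import GPT2Config, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

class MyAttention(GPT2Attention):
    """Stand-in for GPT2AttentionWithRotaryEmbeddings."""
    pass

config = GPT2Config(n_embd=256, n_layer=2, n_head=4)
gpt = GPT2Model(config)
for blk in gpt.h:
    # Swapping the module re-initializes that block's attention weights; in
    # UnifiedVoice this happens before training, so nothing pretrained is lost.
    blk.attn = MyAttention(config, layer_idx=blk.attn.layer_idx)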
@@ -371,9 +445,6 @@ class UnifiedVoice(nn.Module):
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -422,8 +493,6 @@ class UnifiedVoice(nn.Module):
         Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
         model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
         """
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
@@ -477,7 +546,10 @@ class UnifiedVoice(nn.Module):
         return loss_mel.mean()

     def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs):
-        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+        if self.max_mel_tokens == -1:  # Assume if this is the case, max_text_tokens=-1 also.
+            seq_length = 2002  # Arbitrary default.
+        else:
+            seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
             gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
@@ -566,10 +638,11 @@ def register_unified_voice2(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4,
+                       use_rotary_embeddings=True, max_mel_tokens=-1, max_text_tokens=-1)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
             torch.randint(high=8192, size=(2,250)),
             torch.tensor([250*256,195*256]))
-    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
+    #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
@@ -327,7 +327,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clvp.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)