Revert unified_voice back to beginning

I'll be doing my work within unified_voice2
James Betker 2022-01-09 22:34:30 -07:00
parent 432073c5ca
commit ec456b6733


@@ -8,7 +8,6 @@ from transformers import GPT2Model, GPT2Config
 from models.arch_util import AttentionBlock
 from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
 from models.gpt_voice.gpt_asr_hf2 import ResBlock
-from models.gpt_voice.transformer_builders import build_hf_gpt_transformer
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -60,6 +59,10 @@ class MelEncoder(nn.Module):
         return x.permute(0,2,1)
 
 
+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
 class UnifiedGptVoice(nn.Module):
     """
     Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
@@ -104,8 +107,7 @@ class UnifiedGptVoice(nn.Module):
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
         self.shuffle_conditioning = shuffle_conditioning
-        self.layers = layers
-        self.heads = heads
         self.max_mel_tokens = max_mel_tokens
         self.max_text_tokens = max_text_tokens
         self.model_dim = model_dim
@@ -115,14 +117,25 @@ class UnifiedGptVoice(nn.Module):
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
-        self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
-        self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
+        seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
+        self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
+                                     n_positions=seq_length,
+                                     n_ctx=seq_length,
+                                     n_embd=model_dim,
+                                     n_layer=layers,
+                                     n_head=heads,
+                                     gradient_checkpointing=checkpointing,
+                                     use_cache=not checkpointing)
+        self.gpt = GPT2Model(self.gpt_config)
         if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
-            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
+            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
+            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
         else:
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0
+        # Override the built in positional embeddings
+        del self.gpt.wpe
+        self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
         if not use_mel_codes_as_input:
             self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
@@ -301,16 +314,7 @@ class UnifiedGptVoice(nn.Module):
     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
-            # TODO: Decouple gpt_config from this inference model.
-            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-                                    n_positions=self.seq_length,
-                                    n_ctx=self.seq_length,
-                                    n_embd=self.model_dim,
-                                    n_layer=self.layers,
-                                    n_head=self.heads,
-                                    gradient_checkpointing=False,
-                                    use_cache=True)
-            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
+            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
 
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
@@ -328,7 +332,7 @@
         fake_inputs[:,-1] = self.start_mel_token
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=self.seq_length, **hf_generate_kwargs)
+                                            max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
         return gen[:, fake_inputs.shape[1]:]
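The reverted __init__ relies on one trick worth spelling out: GPT-2's learned positional embedding table (wpe) is deleted and replaced with a callable that returns zeros, so positional information comes only from the text/mel positional embeddings that UnifiedGptVoice adds to its input embeddings itself. A minimal standalone sketch of that mechanism, reusing the null_position_embeddings helper from the diff; it is not taken from this commit, and the small dimensions and demo tensor are made up for illustration:

import functools

import torch
from transformers import GPT2Config, GPT2Model


def null_position_embeddings(range, dim):
    # Same helper as in the diff: zeros shaped (batch, seq_len, dim) in place of learned positions.
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


if __name__ == '__main__':
    # Hypothetical small dimensions, chosen only for this sketch.
    model_dim, layers, heads, seq_length = 64, 2, 4, 32
    config = GPT2Config(vocab_size=256, n_positions=seq_length, n_embd=model_dim,
                        n_layer=layers, n_head=heads)
    gpt = GPT2Model(config)

    # Override the built-in positional embeddings, as the reverted __init__ does.
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)

    # The model is then driven with pre-built embeddings (inputs_embeds), so any
    # positional signal has to be added to those embeddings by the caller.
    emb = torch.randn(2, 10, model_dim)
    out = gpt(inputs_embeds=emb, return_dict=True)
    print(out.last_hidden_state.shape)  # torch.Size([2, 10, 64])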