forked from mrq/DL-Art-School
Revert unified_voice back to beginning
I'll be doing my work within unified_voice2
This commit is contained in:
parent
432073c5ca
commit
ec456b6733
|
@ -8,7 +8,6 @@ from transformers import GPT2Model, GPT2Config
|
|||
from models.arch_util import AttentionBlock
|
||||
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
||||
from models.gpt_voice.gpt_asr_hf2 import ResBlock
|
||||
from models.gpt_voice.transformer_builders import build_hf_gpt_transformer
|
||||
from models.tacotron2.text import symbols
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
@ -60,6 +59,10 @@ class MelEncoder(nn.Module):
|
|||
return x.permute(0,2,1)
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
||||
class UnifiedGptVoice(nn.Module):
|
||||
"""
|
||||
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
|
||||
|
@ -104,8 +107,7 @@ class UnifiedGptVoice(nn.Module):
|
|||
self.start_mel_token = start_mel_token
|
||||
self.stop_mel_token = stop_mel_token
|
||||
self.shuffle_conditioning = shuffle_conditioning
|
||||
self.layers = layers
|
||||
self.heads = heads
|
||||
|
||||
self.max_mel_tokens = max_mel_tokens
|
||||
self.max_text_tokens = max_text_tokens
|
||||
self.model_dim = model_dim
|
||||
|
@ -115,14 +117,25 @@ class UnifiedGptVoice(nn.Module):
|
|||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
|
||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
||||
self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
|
||||
self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
|
||||
seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
|
||||
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
self.gpt = GPT2Model(self.gpt_config)
|
||||
if train_solo_embeddings:
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||
else:
|
||||
self.mel_solo_embedding = 0
|
||||
self.text_solo_embedding = 0
|
||||
# Override the built in positional embeddings
|
||||
del self.gpt.wpe
|
||||
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
|
||||
if not use_mel_codes_as_input:
|
||||
self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||
|
@ -301,16 +314,7 @@ class UnifiedGptVoice(nn.Module):
|
|||
|
||||
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
||||
if not hasattr(self, 'inference_model'):
|
||||
# TODO: Decouple gpt_config from this inference model.
|
||||
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
||||
n_positions=self.seq_length,
|
||||
n_ctx=self.seq_length,
|
||||
n_embd=self.model_dim,
|
||||
n_layer=self.layers,
|
||||
n_head=self.heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True)
|
||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
|
@ -328,7 +332,7 @@ class UnifiedGptVoice(nn.Module):
|
|||
fake_inputs[:,-1] = self.start_mel_token
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||
max_length=self.seq_length, **hf_generate_kwargs)
|
||||
max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
|
||||
return gen[:, fake_inputs.shape[1]:]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user