forked from mrq/DL-Art-School
Revert unified_voice back to beginning
I'll be doing my work within unified_voice2
This commit is contained in:
parent
432073c5ca
commit
ec456b6733
|
@@ -8,7 +8,6 @@ from transformers import GPT2Model, GPT2Config
|
||||||
from models.arch_util import AttentionBlock
|
from models.arch_util import AttentionBlock
|
||||||
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
||||||
from models.gpt_voice.gpt_asr_hf2 import ResBlock
|
from models.gpt_voice.gpt_asr_hf2 import ResBlock
|
||||||
from models.gpt_voice.transformer_builders import build_hf_gpt_transformer
|
|
||||||
from models.tacotron2.text import symbols
|
from models.tacotron2.text import symbols
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
@@ -60,6 +59,10 @@ class MelEncoder(nn.Module):
|
||||||
return x.permute(0,2,1)
|
return x.permute(0,2,1)
|
||||||
|
|
||||||
|
|
||||||
|
def null_position_embeddings(range, dim):
    """
    Return an all-zero stand-in for a positional-embedding lookup.

    range: tensor with at least two dimensions; only its first two sizes and its device are read.
           (The name shadows the `range` builtin; it is kept so existing callers are unaffected.)
    dim: size of the trailing dimension of the returned tensor.

    Returns a zero tensor of shape (range.shape[0], range.shape[1], dim) on the same device as `range`.
    """
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||||
|
|
||||||
|
|
||||||
class UnifiedGptVoice(nn.Module):
|
class UnifiedGptVoice(nn.Module):
|
||||||
"""
|
"""
|
||||||
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
|
Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
|
||||||
|
@@ -104,8 +107,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.start_mel_token = start_mel_token
|
self.start_mel_token = start_mel_token
|
||||||
self.stop_mel_token = stop_mel_token
|
self.stop_mel_token = stop_mel_token
|
||||||
self.shuffle_conditioning = shuffle_conditioning
|
self.shuffle_conditioning = shuffle_conditioning
|
||||||
self.layers = layers
|
|
||||||
self.heads = heads
|
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_text_tokens = max_text_tokens
|
self.max_text_tokens = max_text_tokens
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
|
@@ -115,14 +117,25 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
|
||||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
||||||
self.seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
|
seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
|
||||||
self.gpt = build_hf_gpt_transformer(layers, model_dim, heads, number_mel_codes, self.seq_length, checkpointing)
|
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||||
|
n_positions=seq_length,
|
||||||
|
n_ctx=seq_length,
|
||||||
|
n_embd=model_dim,
|
||||||
|
n_layer=layers,
|
||||||
|
n_head=heads,
|
||||||
|
gradient_checkpointing=checkpointing,
|
||||||
|
use_cache=not checkpointing)
|
||||||
|
self.gpt = GPT2Model(self.gpt_config)
|
||||||
if train_solo_embeddings:
|
if train_solo_embeddings:
|
||||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||||
else:
|
else:
|
||||||
self.mel_solo_embedding = 0
|
self.mel_solo_embedding = 0
|
||||||
self.text_solo_embedding = 0
|
self.text_solo_embedding = 0
|
||||||
|
# Override the built in positional embeddings
|
||||||
|
del self.gpt.wpe
|
||||||
|
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||||
|
|
||||||
if not use_mel_codes_as_input:
|
if not use_mel_codes_as_input:
|
||||||
self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
|
self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||||
|
@@ -301,16 +314,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
|
|
||||||
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
# TODO: Decouple gpt_config from this inference model.
|
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
||||||
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
|
||||||
n_positions=self.seq_length,
|
|
||||||
n_ctx=self.seq_length,
|
|
||||||
n_embd=self.model_dim,
|
|
||||||
n_layer=self.layers,
|
|
||||||
n_head=self.heads,
|
|
||||||
gradient_checkpointing=False,
|
|
||||||
use_cache=True)
|
|
||||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
|
||||||
|
|
||||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
|
@@ -328,7 +332,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
fake_inputs[:,-1] = self.start_mel_token
|
fake_inputs[:,-1] = self.start_mel_token
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||||
max_length=self.seq_length, **hf_generate_kwargs)
|
max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
|
||||||
return gen[:, fake_inputs.shape[1]:]
|
return gen[:, fake_inputs.shape[1]:]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user