forked from mrq/DL-Art-School
Shuffle conditioning inputs along the positional axis to reduce fitting on prosody and other positional information
The mels should still retain some short-range positional information the model can use for tone and frequencies, for example.
This commit is contained in:
parent
53858b2055
commit
48e3ee9a5b
|
@ -8,6 +8,7 @@ from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedM
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||||
|
|
||||||
|
from models.arch_util import AttentionBlock
|
||||||
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
||||||
from models.gpt_voice.mini_encoder import AudioMiniEncoder
|
from models.gpt_voice.mini_encoder import AudioMiniEncoder
|
||||||
from models.tacotron2.text import symbols
|
from models.tacotron2.text import symbols
|
||||||
|
@ -15,6 +16,28 @@ from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
spec_dim,
|
||||||
|
embedding_dim,
|
||||||
|
attn_blocks=4,
|
||||||
|
num_attn_heads=4,
|
||||||
|
do_checkpointing=False):
|
||||||
|
super().__init__()
|
||||||
|
attn = []
|
||||||
|
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||||
|
for a in range(attn_blocks):
|
||||||
|
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||||
|
self.attn = nn.Sequential(*attn)
|
||||||
|
self.dim = embedding_dim
|
||||||
|
self.do_checkpointing = do_checkpointing
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.init(x)
|
||||||
|
h = self.attn(h)
|
||||||
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
class GptTtsHf(nn.Module):
|
class GptTtsHf(nn.Module):
|
||||||
NUMBER_TEXT_TOKENS = len(symbols)+1
|
NUMBER_TEXT_TOKENS = len(symbols)+1
|
||||||
START_TEXT_TOKEN = len(symbols)
|
START_TEXT_TOKEN = len(symbols)
|
||||||
|
@ -24,14 +47,14 @@ class GptTtsHf(nn.Module):
|
||||||
STOP_MEL_TOKEN = 8193
|
STOP_MEL_TOKEN = 8193
|
||||||
|
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
|
||||||
checkpointing=True, mel_length_compression=1024, max_conditioning_length=44100//256):
|
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim)
|
||||||
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
|
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
|
||||||
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
|
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
|
||||||
|
@ -54,19 +77,12 @@ class GptTtsHf(nn.Module):
|
||||||
tar = F.pad(input, (0,1), value=stop_token)
|
tar = F.pad(input, (0,1), value=stop_token)
|
||||||
return inp, tar
|
return inp, tar
|
||||||
|
|
||||||
def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False):
|
def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False):
|
||||||
text_emb = self.text_embedding(text_inputs)
|
text_emb = self.text_embedding(text_inputs)
|
||||||
|
cond = self.conditioning_encoder(cond_input).unsqueeze(1)
|
||||||
conds = []
|
|
||||||
for k in range(cond_inputs.shape[1]):
|
|
||||||
conds.append(self.conditioning_encoder(cond_inputs[:, k]))
|
|
||||||
while len(conds) < self.max_conditioning_inputs:
|
|
||||||
conds.append(conds[-1])
|
|
||||||
conds = torch.stack(conds, dim=1)
|
|
||||||
|
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
||||||
|
|
||||||
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
emb = torch.cat([text_emb, cond, mel_emb], dim=1)
|
||||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||||
if get_attns:
|
if get_attns:
|
||||||
return gpt_out.attentions
|
return gpt_out.attentions
|
||||||
|
@ -81,7 +97,7 @@ class GptTtsHf(nn.Module):
|
||||||
|
|
||||||
return text_logits, mel_logits
|
return text_logits, mel_logits
|
||||||
|
|
||||||
def forward(self, text_inputs, cond_inputs, mel_targets, wav_lengths, return_attentions=False):
|
def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False):
|
||||||
"""
|
"""
|
||||||
Forward pass
|
Forward pass
|
||||||
text_inputs: long tensor, (b,t)
|
text_inputs: long tensor, (b,t)
|
||||||
|
@ -95,18 +111,14 @@ class GptTtsHf(nn.Module):
|
||||||
if mel_lengths[b] < mel_targets.shape[-1]:
|
if mel_lengths[b] < mel_targets.shape[-1]:
|
||||||
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
|
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
|
||||||
|
|
||||||
# Format conditioning inputs properly.
|
# Randomly permute the conditioning spectrogram, to destroy any structure present.
|
||||||
if len(cond_inputs.shape) == 3:
|
cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
|
||||||
cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1}
|
if cond_input.shape[-1] > self.max_conditioning_length:
|
||||||
if cond_inputs.shape[-1] > self.max_conditioning_length:
|
cond_input = cond_input[:,:,:self.max_conditioning_length]
|
||||||
# Remember, that this doesn't necessarily mean that the conditioning inputs aren't mostly zero-padded, so
|
|
||||||
# skew trimming towards the front end of the clip.
|
|
||||||
rand_clip = random.randint(0, min(50, cond_inputs.shape[-1]-self.max_conditioning_length))
|
|
||||||
cond_inputs = cond_inputs[:,:,:,rand_clip:rand_clip+self.max_conditioning_length]
|
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
|
||||||
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_inputs, get_attns=return_attentions)
|
text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions)
|
||||||
if return_attentions:
|
if return_attentions:
|
||||||
return mel_logits
|
return mel_logits
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
|
@ -153,6 +165,6 @@ def register_gpt_tts_hf(opt_net, opt):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gpt = GptTtsHf(model_dim=1024, heads=16)
|
gpt = GptTtsHf(model_dim=1024, heads=16)
|
||||||
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
|
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
|
||||||
torch.randn(2,80,800),
|
torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800),
|
||||||
torch.randint(high=8192, size=(2,250)),
|
torch.randint(high=8192, size=(2,250)),
|
||||||
torch.tensor([150*256,195*256]))
|
torch.tensor([150*256,195*256]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user