Various fixes to gpt_tts_hf

This commit is contained in:
James Betker 2021-12-16 23:28:44 -07:00
parent 62c8ed9a29
commit 9e8a9bf6ca
2 changed files with 22 additions and 18 deletions

View File

@ -219,11 +219,13 @@ if __name__ == '__main__':
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,
'needs_collate': True,
'max_wav_length': 256000,
'needs_collate': False,
'max_wav_length': 255995,
'max_text_length': 200,
'sample_rate': 22050,
'load_conditioning': True,
'num_conditioning_candidates': 3,
'conditioning_length': 44100,
}
from data import create_dataset, create_dataloader

View File

@ -33,10 +33,11 @@ class GptTtsHf(nn.Module):
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 2, model_dim)
self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
seq_length = 4+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
n_positions=seq_length,
n_ctx=seq_length,
@ -56,14 +57,10 @@ class GptTtsHf(nn.Module):
assert cond_inputs.shape[1] <= self.max_conditioning_inputs
assert mel_targets.shape[1] <= self.max_mel_tokens
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
mel_targets = F.pad(mel_targets, (0, self.max_mel_tokens - mel_targets.shape[1]), value=self.STOP_MEL_TOKEN)
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
text_emb = self.gpt.get_input_embeddings()(text_targets)
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
conds = []
@ -72,7 +69,12 @@ class GptTtsHf(nn.Module):
while len(conds) < self.max_conditioning_inputs:
conds.append(conds[-1])
conds = torch.stack(conds, dim=1)
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
conds = conds + self.conditioning_embedding
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
mel_targets = F.pad(mel_targets, (0,1), value=self.STOP_MEL_TOKEN)
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
@ -118,8 +120,8 @@ class GptTtsHf(nn.Module):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
text_emb = self.gpt.get_input_embeddings()(text_targets)
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
conds = []
@ -133,11 +135,11 @@ class GptTtsHf(nn.Module):
emb = torch.cat([text_emb, conds], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs = torch.full((text_inputs.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.START_MEL_TOKEN
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, self.max_mel_frames:]
@ -148,7 +150,7 @@ def register_gpt_tts_hf(opt_net, opt):
if __name__ == '__main__':
gpt = GptTtsHf(model_dim=1024, heads=16)
l = gpt(torch.randint(high=len(symbols), size=(2,100)),
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
torch.randn(2,2,80,800),
torch.randint(high=8192, size=(2,200)),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]))