Various fixes to gpt_tts_hf
This commit is contained in:
parent
62c8ed9a29
commit
9e8a9bf6ca
|
@ -219,11 +219,13 @@ if __name__ == '__main__':
|
|||
'phase': 'train',
|
||||
'n_workers': 0,
|
||||
'batch_size': batch_sz,
|
||||
'needs_collate': True,
|
||||
'max_wav_length': 256000,
|
||||
'needs_collate': False,
|
||||
'max_wav_length': 255995,
|
||||
'max_text_length': 200,
|
||||
'sample_rate': 22050,
|
||||
'load_conditioning': True,
|
||||
'num_conditioning_candidates': 3,
|
||||
'conditioning_length': 44100,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
|
|
|
@ -33,10 +33,11 @@ class GptTtsHf(nn.Module):
|
|||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||
self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
|
||||
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
|
||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 2, model_dim)
|
||||
self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
|
||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
||||
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
||||
seq_length = 4+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
||||
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
|
@ -56,14 +57,10 @@ class GptTtsHf(nn.Module):
|
|||
assert cond_inputs.shape[1] <= self.max_conditioning_inputs
|
||||
assert mel_targets.shape[1] <= self.max_mel_tokens
|
||||
|
||||
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
|
||||
mel_targets = F.pad(mel_targets, (0, self.max_mel_tokens - mel_targets.shape[1]), value=self.STOP_MEL_TOKEN)
|
||||
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
|
||||
|
||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
||||
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.gpt.get_input_embeddings()(text_targets)
|
||||
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.text_embedding(text_targets)
|
||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
||||
|
||||
conds = []
|
||||
|
@ -72,7 +69,12 @@ class GptTtsHf(nn.Module):
|
|||
while len(conds) < self.max_conditioning_inputs:
|
||||
conds.append(conds[-1])
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
|
||||
conds = conds + self.conditioning_embedding
|
||||
|
||||
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
|
||||
mel_targets = F.pad(mel_targets, (0,1), value=self.STOP_MEL_TOKEN)
|
||||
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
|
||||
|
||||
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||
|
@ -118,8 +120,8 @@ class GptTtsHf(nn.Module):
|
|||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
||||
|
||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
||||
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.gpt.get_input_embeddings()(text_targets)
|
||||
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||
text_emb = self.text_embedding(text_targets)
|
||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
||||
|
||||
conds = []
|
||||
|
@ -133,11 +135,11 @@ class GptTtsHf(nn.Module):
|
|||
emb = torch.cat([text_emb, conds], dim=1)
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs = torch.full((text_inputs.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||
return gen[:, self.max_mel_frames:]
|
||||
|
||||
|
||||
|
@ -148,7 +150,7 @@ def register_gpt_tts_hf(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
gpt = GptTtsHf(model_dim=1024, heads=16)
|
||||
l = gpt(torch.randint(high=len(symbols), size=(2,100)),
|
||||
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
|
||||
torch.randn(2,2,80,800),
|
||||
torch.randint(high=8192, size=(2,200)),
|
||||
torch.randint(high=8192, size=(2,250)),
|
||||
torch.tensor([150*256,195*256]))
|
||||
|
|
Loading…
Reference in New Issue
Block a user