Various fixes to gpt_tts_hf
This commit is contained in:
parent
62c8ed9a29
commit
9e8a9bf6ca
|
@ -219,11 +219,13 @@ if __name__ == '__main__':
|
||||||
'phase': 'train',
|
'phase': 'train',
|
||||||
'n_workers': 0,
|
'n_workers': 0,
|
||||||
'batch_size': batch_sz,
|
'batch_size': batch_sz,
|
||||||
'needs_collate': True,
|
'needs_collate': False,
|
||||||
'max_wav_length': 256000,
|
'max_wav_length': 255995,
|
||||||
'max_text_length': 200,
|
'max_text_length': 200,
|
||||||
'sample_rate': 22050,
|
'sample_rate': 22050,
|
||||||
'load_conditioning': True,
|
'load_conditioning': True,
|
||||||
|
'num_conditioning_candidates': 3,
|
||||||
|
'conditioning_length': 44100,
|
||||||
}
|
}
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
|
||||||
|
|
|
@ -33,10 +33,11 @@ class GptTtsHf(nn.Module):
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
|
||||||
self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 2, model_dim)
|
||||||
|
self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
|
||||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
|
||||||
seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
seq_length = 4+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
|
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
n_ctx=seq_length,
|
n_ctx=seq_length,
|
||||||
|
@ -56,14 +57,10 @@ class GptTtsHf(nn.Module):
|
||||||
assert cond_inputs.shape[1] <= self.max_conditioning_inputs
|
assert cond_inputs.shape[1] <= self.max_conditioning_inputs
|
||||||
assert mel_targets.shape[1] <= self.max_mel_tokens
|
assert mel_targets.shape[1] <= self.max_mel_tokens
|
||||||
|
|
||||||
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
|
|
||||||
mel_targets = F.pad(mel_targets, (0, self.max_mel_tokens - mel_targets.shape[1]), value=self.STOP_MEL_TOKEN)
|
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
|
|
||||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
|
|
||||||
|
|
||||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
||||||
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
|
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||||
text_emb = self.gpt.get_input_embeddings()(text_targets)
|
text_emb = self.text_embedding(text_targets)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
||||||
|
|
||||||
conds = []
|
conds = []
|
||||||
|
@ -72,7 +69,12 @@ class GptTtsHf(nn.Module):
|
||||||
while len(conds) < self.max_conditioning_inputs:
|
while len(conds) < self.max_conditioning_inputs:
|
||||||
conds.append(conds[-1])
|
conds.append(conds[-1])
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
|
conds = conds + self.conditioning_embedding
|
||||||
|
|
||||||
|
mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
|
||||||
|
mel_targets = F.pad(mel_targets, (0,1), value=self.STOP_MEL_TOKEN)
|
||||||
|
mel_emb = self.gpt.get_input_embeddings()(mel_targets)
|
||||||
|
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
|
||||||
|
|
||||||
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
||||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||||
|
@ -118,8 +120,8 @@ class GptTtsHf(nn.Module):
|
||||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
||||||
|
|
||||||
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
|
||||||
text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
|
text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||||
text_emb = self.gpt.get_input_embeddings()(text_targets)
|
text_emb = self.text_embedding(text_targets)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
|
||||||
|
|
||||||
conds = []
|
conds = []
|
||||||
|
@ -133,11 +135,11 @@ class GptTtsHf(nn.Module):
|
||||||
emb = torch.cat([text_emb, conds], dim=1)
|
emb = torch.cat([text_emb, conds], dim=1)
|
||||||
self.inference_model.store_mel_emb(emb)
|
self.inference_model.store_mel_emb(emb)
|
||||||
|
|
||||||
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
fake_inputs = torch.full((text_inputs.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||||
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
|
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||||
return gen[:, self.max_mel_frames:]
|
return gen[:, self.max_mel_frames:]
|
||||||
|
|
||||||
|
|
||||||
|
@ -148,7 +150,7 @@ def register_gpt_tts_hf(opt_net, opt):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gpt = GptTtsHf(model_dim=1024, heads=16)
|
gpt = GptTtsHf(model_dim=1024, heads=16)
|
||||||
l = gpt(torch.randint(high=len(symbols), size=(2,100)),
|
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
|
||||||
torch.randn(2,2,80,800),
|
torch.randn(2,2,80,800),
|
||||||
torch.randint(high=8192, size=(2,200)),
|
torch.randint(high=8192, size=(2,250)),
|
||||||
torch.tensor([150*256,195*256]))
|
torch.tensor([150*256,195*256]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user