Unified: automatically clip inputs according to specified max length to improve inference time

This commit is contained in:
James Betker 2022-01-06 10:13:45 -07:00
parent 61cd351b71
commit 525addffab

View File

@ -115,9 +115,9 @@ class UnifiedGptVoice(nn.Module):
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length,
n_ctx=seq_length,
@ -212,13 +212,14 @@ class UnifiedGptVoice(nn.Module):
else:
return first_logits
def forward(self, speech_conditioning_input, text_inputs, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
speech_conditioning_input: MEL float tensor, (b,80,s)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
raw_mels: MEL float tensor (b,80,s)
@ -226,7 +227,16 @@ class UnifiedGptVoice(nn.Module):
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@ -235,7 +245,7 @@ class UnifiedGptVoice(nn.Module):
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 4))
mel_inp = F.pad(raw_mels, (0, 8))
else:
mel_inp = mel_codes
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
@ -251,13 +261,18 @@ class UnifiedGptVoice(nn.Module):
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def text_forward(self, speech_conditioning_input, text_inputs):
def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
"""
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
"""
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@ -274,7 +289,14 @@ class UnifiedGptVoice(nn.Module):
"""
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@ -294,6 +316,7 @@ class UnifiedGptVoice(nn.Module):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
@ -319,10 +342,10 @@ def register_unified_gpt_voice(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=False)
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True)
l = gpt(torch.randn(2, 80, 800),
torch.randint(high=len(symbols), size=(2,80)),
torch.tensor([32, 80]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256]),
raw_mels=torch.randn(2,80,1000))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)))
torch.tensor([150*256,195*256]))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))