Unified: automatically clip inputs according to specified max length to improve inference time

James Betker 2022-01-06 10:13:45 -07:00
parent 61cd351b71
commit 525addffab


@@ -115,9 +115,9 @@ class UnifiedGptVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 1, model_dim)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
-        seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
+        self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
+        seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
         self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
                                      n_positions=seq_length,
                                      n_ctx=seq_length,
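
Why the +2 / +4: each clipped sequence now carries one extra padded stop token on top of its start/stop wrapping, so each positional table needs one more slot and the total context grows by two tokens per modality. A quick arithmetic check of the budget (the concrete numbers below are made up; only the identity matters):

    # Hypothetical sanity check, not code from this commit.
    max_text_tokens, max_mel_tokens, max_conditioning_inputs = 120, 250, 1
    text_budget = max_text_tokens + 2   # matches the enlarged text_pos_embedding
    mel_budget = max_mel_tokens + 2     # matches the enlarged mel_pos_embedding
    assert text_budget + mel_budget + max_conditioning_inputs == \
        4 + max_text_tokens + max_mel_tokens + max_conditioning_inputs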
@@ -212,13 +212,14 @@ class UnifiedGptVoice(nn.Module):
         else:
             return first_logits

-    def forward(self, speech_conditioning_input, text_inputs, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).

         speech_conditioning_input: MEL float tensor, (b,80,s)
         text_inputs: long tensor, (b,t)
+        text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         raw_mels: MEL float tensor (b,80,s)
@@ -226,7 +227,16 @@ class UnifiedGptVoice(nn.Module):
         assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
         assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'

+        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+        # chopping the inputs by the maximum actual length.
+        max_text_len = text_lengths.max()
+        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
+        max_mel_len = wav_lengths.max() // self.mel_length_compression
+        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
+        if raw_mels is not None:
+            raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+
         if self.shuffle_conditioning:
             speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
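
The clipping above is just a slice down to the batch's longest real sequence followed by a one-token stop pad. A minimal standalone sketch of that behavior (dummy shapes and a hypothetical stop-token id, not the module's real state):

    import torch
    import torch.nn.functional as F

    stop_text_token = 0                              # hypothetical id, for illustration
    text_inputs = torch.randint(1, 100, (2, 120))    # heavily padded micro-batch, (b, t)
    text_lengths = torch.tensor([32, 80])            # true lengths, (b,)

    max_text_len = text_lengths.max()
    clipped = F.pad(text_inputs[:, :max_text_len], (0, 1), value=stop_text_token)
    print(clipped.shape)  # torch.Size([2, 81]): 120 padded columns become 80 real + 1 stop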
@@ -235,7 +245,7 @@ class UnifiedGptVoice(nn.Module):
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
         if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 4))
+            mel_inp = F.pad(raw_mels, (0, 8))
         else:
             mel_inp = mel_codes
         mel_emb = self.gpt.get_input_embeddings()(mel_inp)
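
A plausible reading of the 4 → 8 change (my inference, not stated in the commit): mel codes run at 4x temporal compression over raw mels, and the code sequence now ends with one more token than before, so the raw-mel path needs four additional frames of right-padding to stay length-aligned:

    frames_per_code = 4          # raw mels were already clipped to max_mel_len*4 above
    trailing_tokens = 2          # previously 1; the extra padded stop token adds one more
    print(frames_per_code * trailing_tokens)  # 8, matching F.pad(raw_mels, (0, 8))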
@@ -251,13 +261,18 @@ class UnifiedGptVoice(nn.Module):
             loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_text.mean(), loss_mel.mean(), mel_logits

-    def text_forward(self, speech_conditioning_input, text_inputs):
+    def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
         """
         Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
         model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
         """
         assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'

+        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+        # chopping the inputs by the maximum actual length.
+        max_text_len = text_lengths.max()
+        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
+
         if self.shuffle_conditioning:
             speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@@ -274,7 +289,14 @@ class UnifiedGptVoice(nn.Module):
         """
         assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'

+        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+        # chopping the inputs by the maximum actual length.
+        max_mel_len = wav_lengths.max() // self.mel_length_compression
+        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        if raw_mels is not None:
+            raw_mels = raw_mels[:, :, :max_mel_len*4]
+
         if self.shuffle_conditioning:
             speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@@ -294,6 +316,7 @@ class UnifiedGptVoice(nn.Module):
         if not hasattr(self, 'inference_model'):
             self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)

+        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))

@@ -319,10 +342,10 @@ def register_unified_gpt_voice(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=False)
+    gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True)
     l = gpt(torch.randn(2, 80, 800),
             torch.randint(high=len(symbols), size=(2,80)),
+            torch.tensor([32, 80]),
             torch.randint(high=8192, size=(2,250)),
-            torch.tensor([150*256,195*256]),
-            raw_mels=torch.randn(2,80,1000))
-    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)))
+            torch.tensor([150*256,195*256]))
+    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
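
Callers must now supply per-sample token lengths (`text_lengths`, alongside `wav_lengths` as before). If only a padded batch is at hand, lengths can be recovered by counting non-pad positions. A sketch assuming pad id 0 (hypothetical; use whatever the tokenizer actually pads with):

    import torch

    PAD = 0  # hypothetical pad id
    text_inputs = torch.tensor([[5, 9, 3, PAD, PAD],
                                [7, 2, 8, 1, 6]])
    text_lengths = (text_inputs != PAD).sum(dim=1)   # tensor([3, 5])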