forked from mrq/DL-Art-School
Unified: automatically clip inputs according to specified max length to improve inference time
parent 61cd351b71
commit 525addffab
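In short: training micro-batches arrive padded out to the model-wide maximums, so forward() and text_forward() gain a text_lengths argument and all three padded-input paths (text tokens, MEL codes, raw MELs) are trimmed to the longest real sequence in the batch before embeddings are built, shrinking the attention window. A minimal sketch of the trick (illustrative names, not code from this repo):

import torch
import torch.nn.functional as F

def clip_to_longest(tokens, lengths, stop_token):
    # Trim a (b, t) batch padded to a global max down to the longest
    # actual sequence, then re-append one stop token per row.
    max_len = lengths.max()
    return F.pad(tokens[:, :max_len], (0, 1), value=stop_token)

batch = torch.randint(high=256, size=(2, 80))  # padded to 80 tokens
clipped = clip_to_longest(batch, torch.tensor([32, 15]), stop_token=0)
print(clipped.shape)  # torch.Size([2, 33])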
@@ -115,9 +115,9 @@ class UnifiedGptVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 1, model_dim)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
-        seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
+        self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
+        seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
         self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
                                      n_positions=seq_length,
                                      n_ctx=seq_length,
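Why the embedding tables above grow from max+1 to max+2 slots, and seq_length's constant from 2 to 4 (my reading of the hunk, not a comment from the repo): each clipped stream now gets an explicit stop token appended in forward() on top of the start token that build_aligned_inputs_and_targets prepends, so a stream of n tokens can occupy up to n + 2 positions, and the two streams together account for the 4. A worked check with hypothetical sizes:

# Hypothetical sizes; the real values come from the model config.
max_text_tokens, max_mel_tokens, max_conditioning_inputs = 120, 250, 1
seq_length = 4 + max_text_tokens + max_mel_tokens + max_conditioning_inputs
# One start + one stop position per stream accounts for the 4:
assert seq_length == (max_text_tokens + 2) + (max_mel_tokens + 2) + max_conditioning_inputs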
@@ -212,13 +212,14 @@ class UnifiedGptVoice(nn.Module):
         else:
             return first_logits

-    def forward(self, speech_conditioning_input, text_inputs, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).

         speech_conditioning_input: MEL float tensor, (b,80,s)
         text_inputs: long tensor, (b,t)
+        text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         raw_mels: MEL float tensor (b,80,s)
@@ -226,7 +227,16 @@ class UnifiedGptVoice(nn.Module):
         assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
         assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'

+        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+        # chopping the inputs by the maximum actual length.
+        max_text_len = text_lengths.max()
+        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
+        max_mel_len = wav_lengths.max() // self.mel_length_compression
+        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
+        if raw_mels is not None:
+            raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+
         if self.shuffle_conditioning:
             speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
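Spelling out the MEL arithmetic in this hunk: mel_length_compression is the number of waveform samples covered by one MEL code, so wav_lengths.max() // mel_length_compression is the longest useful code sequence, and the *4 when trimming raw_mels assumes four raw MEL frames per code (consistent with a 1024-sample compression over 256-sample MEL hops, but an assumption on my part). A quick check with made-up lengths:

import torch

mel_length_compression = 1024                   # assumed; set by the model config
wav_lengths = torch.tensor([150*256, 195*256])  # mirrors the smoke test at the bottom
max_mel_len = wav_lengths.max() // mel_length_compression
print(int(max_mel_len))      # 48  -> MEL codes kept per row
print(int(max_mel_len) * 4)  # 192 -> raw MEL frames kept (4 per code, assumed)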
@@ -235,7 +245,7 @@ class UnifiedGptVoice(nn.Module):
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
         if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 4))
+            mel_inp = F.pad(raw_mels, (0, 8))
         else:
             mel_inp = mel_codes
         mel_emb = self.gpt.get_input_embeddings()(mel_inp)
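The pad width doubling above follows the same bookkeeping (again my reading, hedged): raw_mels must stay frame-aligned with mel_codes, which previously grew by one wrapper code (the prepended start token, four frames' worth) and now grows by two (start plus the explicitly appended stop token):

frames_per_code = 4         # assumed raw-MEL frames per MEL code, as above
print(frames_per_code * 1)  # 4 -> old pad width (start code only)
print(frames_per_code * 2)  # 8 -> new pad width (start + stop codes)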
@@ -251,13 +261,18 @@ class UnifiedGptVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_text.mean(), loss_mel.mean(), mel_logits

-    def text_forward(self, speech_conditioning_input, text_inputs):
+    def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
         """
         Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
         model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
         """
         assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'

+        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+        # chopping the inputs by the maximum actual length.
+        max_text_len = text_lengths.max()
+        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
+
         if self.shuffle_conditioning:
             speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
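The docstring's suggestion can be taken literally; a hedged usage sketch of the new text_forward signature (assuming gpt is a constructed UnifiedGptVoice, as in the __main__ block at the bottom):

import torch

cond = torch.zeros(2, 80, 800)               # dummy (b, 80, s) conditioning MEL; zeros per the docstring
text = torch.randint(high=50, size=(2, 80))  # (b, t) text tokens, padded to 80
loss_text = gpt.text_forward(cond, text, torch.tensor([32, 80]))  # real lengths drive the clipping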
@@ -274,7 +289,14 @@ class UnifiedGptVoice(nn.Module):
         """
         assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'

+        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+        # chopping the inputs by the maximum actual length.
+        max_mel_len = wav_lengths.max() // self.mel_length_compression
+        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        if raw_mels is not None:
+            raw_mels = raw_mels[:, :, :max_mel_len*4]
+
         if self.shuffle_conditioning:
             speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
@@ -294,6 +316,7 @@ class UnifiedGptVoice(nn.Module):
         if not hasattr(self, 'inference_model'):
             self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)

+        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))

@@ -319,10 +342,10 @@ def register_unified_gpt_voice(opt_net, opt):


 if __name__ == '__main__':
-    gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=False)
+    gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True)
     l = gpt(torch.randn(2, 80, 800),
             torch.randint(high=len(symbols), size=(2,80)),
+            torch.tensor([32, 80]),
             torch.randint(high=8192, size=(2,250)),
-            torch.tensor([150*256,195*256]))
-    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)))
+            torch.tensor([150*256,195*256]),
+            raw_mels=torch.randn(2,80,1000))
+    gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))