forked from mrq/DL-Art-School
uv3
This commit is contained in:
parent
96a5cc66ee
commit
545453077e
|
@ -238,7 +238,7 @@ class MelEncoder(nn.Module):
|
|||
class UnifiedVoice(nn.Module):
|
||||
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
|
||||
stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1):
|
||||
stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1):
|
||||
"""
|
||||
Args:
|
||||
layers: Number of layers in transformer stack.
|
||||
|
@ -278,6 +278,7 @@ class UnifiedVoice(nn.Module):
|
|||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
|
||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||
self.aligned_head = nn.Linear(model_dim, self.number_aligned_text_codes)
|
||||
|
||||
# Initialize the embeddings per the GPT-2 scheme
|
||||
embeddings = [self.text_embedding, self.mel_embedding]
|
||||
|
@ -310,11 +311,11 @@ class UnifiedVoice(nn.Module):
|
|||
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return mel_input_tokens
|
||||
|
||||
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False):
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
def get_logits(self, speech_conditioning_inputs, text_inputs, text_head, mel_inputs, mel_head, aligned_head, return_latent=False):
|
||||
if mel_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, text_inputs, mel_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||
emb = torch.cat([speech_conditioning_inputs, text_inputs], dim=1)
|
||||
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
|
||||
|
||||
|
@ -322,18 +323,16 @@ class UnifiedVoice(nn.Module):
|
|||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
|
||||
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + text_inputs.shape[1]], enc[:, -mel_inputs.shape[1]:]
|
||||
|
||||
first_logits = enc[:, :first_inputs.shape[1]]
|
||||
first_logits = first_head(first_logits)
|
||||
first_logits = first_logits.permute(0,2,1)
|
||||
if second_inputs is not None:
|
||||
second_logits = enc[:, -second_inputs.shape[1]:]
|
||||
second_logits = second_head(second_logits)
|
||||
second_logits = second_logits.permute(0,2,1)
|
||||
return first_logits, second_logits
|
||||
else:
|
||||
return first_logits
|
||||
text_logits = enc[:, :text_inputs.shape[1]]
|
||||
text_logits = text_head(text_logits).permute(0,2,1)
|
||||
|
||||
mel_logits = enc[:, -mel_inputs.shape[1]:]
|
||||
aligned_logits = aligned_head(mel_logits).permute(0,2,1)
|
||||
mel_logits = mel_head(mel_logits).permute(0,2,1)
|
||||
|
||||
return text_logits, mel_logits, aligned_logits
|
||||
|
||||
|
||||
def get_conditioning_latent(self, speech_conditioning_input):
|
||||
|
@ -346,7 +345,7 @@ class UnifiedVoice(nn.Module):
|
|||
return conds
|
||||
|
||||
|
||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False):
|
||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, aligned_codes, types=None, return_latent=False):
|
||||
"""
|
||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||
(actuated by `text_first`).
|
||||
|
@ -356,6 +355,7 @@ class UnifiedVoice(nn.Module):
|
|||
text_lengths: long tensor, (b,)
|
||||
mel_inputs: long tensor, (b,m)
|
||||
wav_lengths: long tensor, (b,)
|
||||
aligned_codes: long tensor, (b,m/C) where C is some constant.
|
||||
|
||||
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
|
||||
"""
|
||||
|
@ -369,6 +369,10 @@ class UnifiedVoice(nn.Module):
|
|||
|
||||
conds = self.get_conditioning_latent(speech_conditioning_input)
|
||||
|
||||
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1]
|
||||
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
|
||||
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
|
||||
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
|
@ -376,18 +380,15 @@ class UnifiedVoice(nn.Module):
|
|||
mel_emb = self.mel_embedding(mel_inp)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||
|
||||
if text_first:
|
||||
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
else:
|
||||
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
text_logits, mel_logits, aligned_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head,
|
||||
self.aligned_head, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
loss_aligned = F.cross_entropy(aligned_logits, aligned_targets.long())
|
||||
return loss_text.mean(), loss_mel.mean(), loss_aligned.mean(), mel_logits
|
||||
|
||||
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
||||
if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also
|
||||
|
|
Loading…
Reference in New Issue
Block a user