This commit is contained in:
James Betker 2022-05-09 15:36:22 -06:00
parent 96a5cc66ee
commit 545453077e

View File

@ -238,7 +238,7 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module): class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192, mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1): stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1):
""" """
Args: Args:
layers: Number of layers in transformer stack. layers: Number of layers in transformer stack.
@ -278,6 +278,7 @@ class UnifiedVoice(nn.Module):
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes) self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.aligned_head = nn.Linear(model_dim, self.number_aligned_text_codes)
# Initialize the embeddings per the GPT-2 scheme # Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding] embeddings = [self.text_embedding, self.mel_embedding]
@ -310,11 +311,11 @@ class UnifiedVoice(nn.Module):
mel_input_tokens[b, actual_end:] = self.stop_mel_token mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens return mel_input_tokens
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False): def get_logits(self, speech_conditioning_inputs, text_inputs, text_head, mel_inputs, mel_head, aligned_head, return_latent=False):
if second_inputs is not None: if mel_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) emb = torch.cat([speech_conditioning_inputs, text_inputs, mel_inputs], dim=1)
else: else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) emb = torch.cat([speech_conditioning_inputs, text_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
@ -322,18 +323,16 @@ class UnifiedVoice(nn.Module):
enc = self.final_norm(enc) enc = self.final_norm(enc)
if return_latent: if return_latent:
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + text_inputs.shape[1]], enc[:, -mel_inputs.shape[1]:]
first_logits = enc[:, :first_inputs.shape[1]] text_logits = enc[:, :text_inputs.shape[1]]
first_logits = first_head(first_logits) text_logits = text_head(text_logits).permute(0,2,1)
first_logits = first_logits.permute(0,2,1)
if second_inputs is not None: mel_logits = enc[:, -mel_inputs.shape[1]:]
second_logits = enc[:, -second_inputs.shape[1]:] aligned_logits = aligned_head(mel_logits).permute(0,2,1)
second_logits = second_head(second_logits) mel_logits = mel_head(mel_logits).permute(0,2,1)
second_logits = second_logits.permute(0,2,1)
return first_logits, second_logits return text_logits, mel_logits, aligned_logits
else:
return first_logits
def get_conditioning_latent(self, speech_conditioning_input): def get_conditioning_latent(self, speech_conditioning_input):
@ -346,7 +345,7 @@ class UnifiedVoice(nn.Module):
return conds return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False): def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, aligned_codes, types=None, return_latent=False):
""" """
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`). (actuated by `text_first`).
@ -356,6 +355,7 @@ class UnifiedVoice(nn.Module):
text_lengths: long tensor, (b,) text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m) mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,) wav_lengths: long tensor, (b,)
aligned_codes: long tensor, (b,m/C) where C is some constant.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
""" """
@ -369,6 +369,10 @@ class UnifiedVoice(nn.Module):
conds = self.get_conditioning_latent(speech_conditioning_input) conds = self.get_conditioning_latent(speech_conditioning_input)
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1]
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
@ -376,18 +380,15 @@ class UnifiedVoice(nn.Module):
mel_emb = self.mel_embedding(mel_inp) mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
if text_first: text_logits, mel_logits, aligned_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head,
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent) self.aligned_head, return_latent=return_latent)
if return_latent: if return_latent:
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
else:
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent)
if return_latent:
return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits loss_aligned = F.cross_entropy(aligned_logits, aligned_targets.long())
return loss_text.mean(), loss_mel.mean(), loss_aligned.mean(), mel_logits
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also