uv2 add alignment head

This commit is contained in:
James Betker 2022-06-14 15:18:58 -06:00
parent 7ff1fbe2be
commit c68669e1e1

View File

@ -238,22 +238,7 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module): class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192, mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1): stop_mel_token=8193, start_text_token=255, checkpointing=True, types=1, only_alignment_head=False):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
number_mel_codes:
start_mel_token:
stop_mel_token:
checkpointing:
"""
super().__init__() super().__init__()
self.number_text_tokens = number_text_tokens self.number_text_tokens = number_text_tokens
@ -278,6 +263,15 @@ class UnifiedVoice(nn.Module):
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes) self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.alignment_head = nn.Linear(model_dim, 256)
if only_alignment_head:
for p in self.parameters():
p.DO_NOT_TRAIN = True
p.requires_grad = False
for p in self.alignment_head.parameters():
del p.DO_NOT_TRAIN
p.requires_grad = True
# Initialize the embeddings per the GPT-2 scheme # Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding] embeddings = [self.text_embedding, self.mel_embedding]
@ -310,11 +304,8 @@ class UnifiedVoice(nn.Module):
mel_input_tokens[b, actual_end:] = self.stop_mel_token mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens return mel_input_tokens
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False): def get_logits(self, speech_conditioning_inputs, first_inputs, second_inputs, return_latent=False):
if second_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
@ -324,16 +315,19 @@ class UnifiedVoice(nn.Module):
if return_latent: if return_latent:
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
first_logits = enc[:, :first_inputs.shape[1]] text_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits) text_logits = self.text_head(text_logits)
first_logits = first_logits.permute(0,2,1) text_logits = text_logits.permute(0,2,1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1]:] mel_logits = enc[:, -second_inputs.shape[1]:]
second_logits = second_head(second_logits) mel_logits = self.mel_head(mel_logits)
second_logits = second_logits.permute(0,2,1) mel_logits = mel_logits.permute(0,2,1)
return first_logits, second_logits
else: alignment_logits = enc[:, -second_inputs.shape[1]:]
return first_logits alignment_logits = self.alignment_head(alignment_logits)
alignment_logits = alignment_logits.permute(0,2,1)
return text_logits, mel_logits, alignment_logits
def get_conditioning_latent(self, speech_conditioning_input): def get_conditioning_latent(self, speech_conditioning_input):
@ -346,7 +340,7 @@ class UnifiedVoice(nn.Module):
return conds return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False): def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, ctc_codes, wav_lengths, types=None, return_latent=False):
""" """
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`). (actuated by `text_first`).
@ -363,12 +357,22 @@ class UnifiedVoice(nn.Module):
if types is not None: if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1) text_inputs = text_inputs * (1+types).unsqueeze(-1)
# TODO: do this in the dataloader.
for b in range(ctc_codes.shape[0]):
last_code = 0
for j in range(ctc_codes.shape[1]):
if ctc_codes[b][j] == 0:
ctc_codes[b][j] = last_code
else:
last_code = ctc_codes[b][j]
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(mel_codes.shape[-1],), mode='nearest').long().squeeze()
mel_codes = self.set_mel_padding(mel_codes, wav_lengths) mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
alignment_targets = F.pad(alignment_targets, (0,2), value=0)
conds = self.get_conditioning_latent(speech_conditioning_input) conds = self.get_conditioning_latent(speech_conditioning_input)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
@ -376,18 +380,14 @@ class UnifiedVoice(nn.Module):
mel_emb = self.mel_embedding(mel_inp) mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
if text_first: text_logits, mel_logits, alignment_logits = self.get_logits(conds, text_emb, mel_emb, return_latent=return_latent)
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent)
if return_latent: if return_latent:
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
else:
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent)
if return_latent:
return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits loss_alignment = F.cross_entropy(alignment_logits, alignment_targets)
return loss_text.mean(), loss_mel.mean(), loss_alignment, mel_logits
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also