forked from mrq/DL-Art-School
uv2 add alignment head
parent 7ff1fbe2be
commit c68669e1e1
@@ -238,22 +238,7 @@ class MelEncoder(nn.Module):
 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
                  mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
-                 stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1):
-        """
-        Args:
-            layers: Number of layers in transformer stack.
-            model_dim: Operating dimensions of the transformer
-            heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
-            max_text_tokens: Maximum number of text tokens that will be encountered by model.
-            max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
-            max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
-            mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
-            number_text_tokens:
-            number_mel_codes:
-            start_mel_token:
-            stop_mel_token:
-            checkpointing:
-        """
+                 stop_mel_token=8193, start_text_token=255, checkpointing=True, types=1, only_alignment_head=False):
         super().__init__()

         self.number_text_tokens = number_text_tokens
@@ -278,6 +263,15 @@ class UnifiedVoice(nn.Module):
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+        self.alignment_head = nn.Linear(model_dim, 256)
+
+        if only_alignment_head:
+            for p in self.parameters():
+                p.DO_NOT_TRAIN = True
+                p.requires_grad = False
+            for p in self.alignment_head.parameters():
+                del p.DO_NOT_TRAIN
+                p.requires_grad = True

         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding, self.mel_embedding]
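Note: the `only_alignment_head` path freezes the whole model and then re-enables just the new head, so it can be trained on top of frozen UnifiedVoice weights. A minimal standalone sketch of that pattern (the `DO_NOT_TRAIN` attribute is a DL-Art-School trainer convention used to exclude parameters from optimization; the toy modules below are placeholders, not the real model):

import torch.nn as nn

# Placeholder modules standing in for the UnifiedVoice backbone and the new head.
model = nn.ModuleDict({
    'backbone': nn.Linear(512, 512),
    'alignment_head': nn.Linear(512, 256),
})

# Freeze everything...
for p in model.parameters():
    p.DO_NOT_TRAIN = True        # custom attribute checked by the DL-Art-School training loop
    p.requires_grad = False
# ...then re-enable only the alignment head.
for p in model['alignment_head'].parameters():
    del p.DO_NOT_TRAIN
    p.requires_grad = True

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['alignment_head.weight', 'alignment_head.bias']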
@@ -310,11 +304,8 @@ class UnifiedVoice(nn.Module):
             mel_input_tokens[b, actual_end:] = self.stop_mel_token
         return mel_input_tokens

-    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False):
-        if second_inputs is not None:
-            emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
-        else:
-            emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
+    def get_logits(self, speech_conditioning_inputs, first_inputs, second_inputs, return_latent=False):
+        emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)

         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
@@ -324,16 +315,19 @@ class UnifiedVoice(nn.Module):
         if return_latent:
             return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]

-        first_logits = enc[:, :first_inputs.shape[1]]
-        first_logits = first_head(first_logits)
-        first_logits = first_logits.permute(0,2,1)
-        if second_inputs is not None:
-            second_logits = enc[:, -second_inputs.shape[1]:]
-            second_logits = second_head(second_logits)
-            second_logits = second_logits.permute(0,2,1)
-            return first_logits, second_logits
-        else:
-            return first_logits
+        text_logits = enc[:, :first_inputs.shape[1]]
+        text_logits = self.text_head(text_logits)
+        text_logits = text_logits.permute(0,2,1)
+
+        mel_logits = enc[:, -second_inputs.shape[1]:]
+        mel_logits = self.mel_head(mel_logits)
+        mel_logits = mel_logits.permute(0,2,1)
+
+        alignment_logits = enc[:, -second_inputs.shape[1]:]
+        alignment_logits = self.alignment_head(alignment_logits)
+        alignment_logits = alignment_logits.permute(0,2,1)
+
+        return text_logits, mel_logits, alignment_logits

     def get_conditioning_latent(self, speech_conditioning_input):
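For reference, the logits now come from three heads over the same GPT hidden states: the text slice goes through `text_head`, and the mel slice goes through both `mel_head` and the new 256-way `alignment_head`; each result is permuted to (batch, classes, sequence) so it can be fed straight to `F.cross_entropy`. A shape-only sketch with made-up sizes (not the real model):

import torch
import torch.nn as nn

b, s_text, s_mel, model_dim = 2, 12, 60, 512
enc = torch.randn(b, s_text + s_mel, model_dim)   # stand-in for the normalized GPT hidden states

text_head = nn.Linear(model_dim, 257)             # number_text_tokens*types + 1
mel_head = nn.Linear(model_dim, 8194)             # number_mel_codes
alignment_head = nn.Linear(model_dim, 256)        # new: one logit per CTC character code

text_logits = text_head(enc[:, :s_text]).permute(0, 2, 1)            # (b, 257, s_text)
mel_logits = mel_head(enc[:, -s_mel:]).permute(0, 2, 1)              # (b, 8194, s_mel)
alignment_logits = alignment_head(enc[:, -s_mel:]).permute(0, 2, 1)  # (b, 256, s_mel)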
@@ -346,7 +340,7 @@ class UnifiedVoice(nn.Module):
         return conds


-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, ctc_codes, wav_lengths, types=None, return_latent=False):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -363,12 +357,22 @@ class UnifiedVoice(nn.Module):
         if types is not None:
             text_inputs = text_inputs * (1+types).unsqueeze(-1)

+        # TODO: do this in the dataloader.
+        for b in range(ctc_codes.shape[0]):
+            last_code = 0
+            for j in range(ctc_codes.shape[1]):
+                if ctc_codes[b][j] == 0:
+                    ctc_codes[b][j] = last_code
+                else:
+                    last_code = ctc_codes[b][j]
+        alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(mel_codes.shape[-1],), mode='nearest').long().squeeze()
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
         text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
         mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
+        alignment_targets = F.pad(alignment_targets, (0,2), value=0)

         conds = self.get_conditioning_latent(speech_conditioning_input)

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
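The new block builds per-frame character targets from CTC output: blanks (code 0) are forward-filled with the most recent non-blank code, then the sequence is stretched to the mel-code length with nearest-neighbour interpolation so every mel position gets a target for the alignment head. A self-contained sketch of the same transform on toy data (the tensors and lengths here are invented):

import torch
import torch.nn.functional as F

ctc_codes = torch.tensor([[5, 0, 0, 9, 0, 3, 0, 0]])  # toy CTC codes; 0 is the blank symbol
mel_len = 20                                          # stands in for mel_codes.shape[-1]

# Forward-fill blanks with the last emitted code, per batch item (same loop as the commit).
for b in range(ctc_codes.shape[0]):
    last_code = 0
    for j in range(ctc_codes.shape[1]):
        if ctc_codes[b][j] == 0:
            ctc_codes[b][j] = last_code
        else:
            last_code = ctc_codes[b][j]

# Stretch the character track to mel resolution with nearest-neighbour interpolation.
# (The commit uses .squeeze(), which would also drop a batch dimension of 1; .squeeze(1) keeps it.)
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(),
                                  size=(mel_len,), mode='nearest').long().squeeze(1)
print(alignment_targets.shape)  # torch.Size([1, 20]); each mel frame now carries a character code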
@@ -376,18 +380,14 @@ class UnifiedVoice(nn.Module):
         mel_emb = self.mel_embedding(mel_inp)
         mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)

-        if text_first:
-            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent)
-            if return_latent:
-                return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
-        else:
-            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent)
-            if return_latent:
-                return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
+        text_logits, mel_logits, alignment_logits = self.get_logits(conds, text_emb, mel_emb, return_latent=return_latent)
+        if return_latent:
+            return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.

         loss_text = F.cross_entropy(text_logits, text_targets.long())
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
-        return loss_text.mean(), loss_mel.mean(), mel_logits
+        loss_alignment = F.cross_entropy(alignment_logits, alignment_targets)
+        return loss_text.mean(), loss_mel.mean(), loss_alignment, mel_logits

     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also
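With this change `forward` returns a third loss, so callers have to unpack and weight it. A hypothetical usage sketch (the variable names and weights are illustrative; in DL-Art-School the actual loss weighting is normally configured in the experiment YAML rather than hard-coded like this):

# Hypothetical training-step excerpt; `model` is a UnifiedVoice instance and the
# batch tensors follow the shapes forward() expects.
loss_text, loss_mel, loss_alignment, mel_logits = model(
    speech_conditioning_input, text_inputs, text_lengths,
    mel_codes, ctc_codes, wav_lengths)

total = 0.01 * loss_text + 1.0 * loss_mel + 0.1 * loss_alignment  # illustrative weights only
total.backward()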