Simplify and conform gpt_asr_hf2
parent a5b4bee719 · commit a9ee5b624f
@@ -211,10 +211,11 @@ def null_position_embeddings(range, dim):
 class GptAsrHf2(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
-                 number_text_tokens=512, start_token=511):
+                 number_text_tokens=512, start_token=511, stop_token=0):
         super().__init__()
         self.number_text_tokens = number_text_tokens
         self.start_token = start_token
+        self.stop_token = 0
 
         self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
         self.max_symbols_per_phrase = max_symbols_per_phrase
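For reference, a minimal instantiation sketch using only the defaults visible in the signature above; the values are illustrative, not a recommended configuration:

    # Illustrative only: the defaults from the __init__ signature in this hunk.
    model = GptAsrHf2(layers=8, model_dim=512, heads=8,
                      max_symbols_per_phrase=800, max_mel_frames=3000,
                      number_text_tokens=512, start_token=511, stop_token=0)
    # Internally max_mel_frames is divided by 4 (3000 // 4 = 750 encoded mel positions).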
@@ -248,12 +249,12 @@ class GptAsrHf2(nn.Module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def get_logits(self, mel_inputs, text_targets, pos_emb, get_attns=False):
-        # Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_inputs = F.pad(text_targets, (1,0), value=self.start_token)[:, :-1]
-        text_emb = self.gpt.get_input_embeddings()(text_inputs)
-        text_emb = text_emb + pos_emb(torch.arange(text_emb.shape[1], device=text_inputs.device))
+    def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
+        inp = F.pad(input, (1,0), value=start_token)
+        tar = F.pad(input, (0,1), value=stop_token)
+        return inp, tar
+
+    def get_logits(self, mel_inputs, text_emb, get_attns=False):
         if mel_inputs is None:
             emb = text_emb
             mel_len = 0
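The new build_aligned_inputs_and_targets helper centralizes the start/stop padding that get_logits previously did inline. A minimal sketch of what the helper produces, with toy tensor values that are not from the repository:

    import torch
    import torch.nn.functional as F

    start_token, stop_token = 511, 0
    text = torch.tensor([[45, 12, 99]])                  # toy batch of text tokens

    inp = F.pad(text, (1, 0), value=start_token)         # [[511, 45, 12, 99]] -> fed to the GPT
    tar = F.pad(text, (0, 1), value=stop_token)          # [[45, 12, 99,  0]]  -> cross-entropy targets
    # inp and tar have the same length, so the logit at position i is trained
    # to predict tar[:, i], i.e. the token that follows inp[:, i].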
@@ -272,21 +273,26 @@ class GptAsrHf2(nn.Module):
         text_logits = text_logits.permute(0,2,1)
         return text_logits
 
-    def forward(self, mel_inputs, text_targets, return_attentions=False):
-        text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
-        text_logits = self.get_logits(mel_inputs, text_targets, self.text_pos_embedding, get_attns=return_attentions)
+    def forward(self, mel_inputs, text_inputs, return_attentions=False):
+        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
+        text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
+                   self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_logits = self.get_logits(mel_inputs, text_emb, get_attns=return_attentions)
 
         if return_attentions:
             return text_logits # These weren't really the logits.
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
-    def text_only(self, text_targets):
-        text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
-        text_logits = self.get_logits(None, text_targets, self.text_solo_pos_embedding)
+    def text_only(self, text_inputs):
+        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
+        text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
+                   self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_logits = self.get_logits(None, text_emb)
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
-    def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8):
+    def inference(self, mel_inputs, do_sample=False, temperature=1.0, num_beams=8):
         if not hasattr(self, 'inference_model'):
             self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
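After this change, forward and text_only take raw text token tensors and derive the aligned inputs and targets themselves, so callers no longer pre-pad a stop token. A rough usage sketch under assumed shapes; model, mel, and tokens are placeholder names, and the mel input dimensions are not specified by this diff:

    # Hypothetical training-step sketch; shapes and names are assumptions, not from the repository.
    model = GptAsrHf2()
    mel = torch.randn(2, 80, 400)                # assumed mel spectrogram batch (B, channels, frames)
    tokens = torch.randint(1, 511, (2, 50))      # raw text tokens; no manual stop padding needed now
    loss, logits = model(mel, tokens)            # forward() aligns start/stop tokens internally
    text_loss, _ = model.text_only(tokens)       # text-only path reuses the same alignment helper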