From a9ee5b624fa68aa66a37948be2f02acb9c0773d9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 28 Dec 2021 11:54:33 -0700
Subject: [PATCH] Simplify and conform gpt_asr_hf2

---
 codes/models/gpt_voice/gpt_asr_hf2.py | 32 ++++++++++++++++---
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 246cde02..a4991b44 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -211,10 +211,11 @@ def null_position_embeddings(range, dim):
 
 class GptAsrHf2(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
-                 number_text_tokens=512, start_token=511):
+                 number_text_tokens=512, start_token=511, stop_token=0):
         super().__init__()
         self.number_text_tokens = number_text_tokens
         self.start_token = start_token
+        self.stop_token = stop_token
         self.max_mel_frames = max_mel_frames // 4  # Mel frames are reduced by a factor of 4 during encoding.
         self.max_symbols_per_phrase = max_symbols_per_phrase
 
@@ -248,12 +249,12 @@ class GptAsrHf2(nn.Module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
+        inp = F.pad(input, (1,0), value=start_token)
+        tar = F.pad(input, (0,1), value=stop_token)
+        return inp, tar
 
-    def get_logits(self, mel_inputs, text_targets, pos_emb, get_attns=False):
-        # Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_inputs = F.pad(text_targets, (1,0), value=self.start_token)[:, :-1]
-        text_emb = self.gpt.get_input_embeddings()(text_inputs)
-        text_emb = text_emb + pos_emb(torch.arange(text_emb.shape[1], device=text_inputs.device))
+    def get_logits(self, mel_inputs, text_emb, get_attns=False):
         if mel_inputs is None:
             emb = text_emb
             mel_len = 0
@@ -272,21 +273,26 @@ class GptAsrHf2(nn.Module):
         text_logits = text_logits.permute(0,2,1)
         return text_logits
 
-    def forward(self, mel_inputs, text_targets, return_attentions=False):
-        text_targets = F.pad(text_targets, (0,1))  # Pad the targets with a <0> so that all have a "stop" token.
-        text_logits = self.get_logits(mel_inputs, text_targets, self.text_pos_embedding, get_attns=return_attentions)
+    def forward(self, mel_inputs, text_inputs, return_attentions=False):
+        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
+        text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
+                   self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_logits = self.get_logits(mel_inputs, text_emb, get_attns=return_attentions)
+
         if return_attentions:
             return text_logits  # These weren't really the logits.
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
-    def text_only(self, text_targets):
-        text_targets = F.pad(text_targets, (0,1))  # Pad the targets with a <0> so that all have a "stop" token.
-        text_logits = self.get_logits(None, text_targets, self.text_solo_pos_embedding)
+    def text_only(self, text_inputs):
+        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
+        text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
+                   self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_logits = self.get_logits(None, text_emb)
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
-    def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8):
+    def inference(self, mel_inputs, do_sample=False, temperature=1.0, num_beams=8):
         if not hasattr(self, 'inference_model'):
             self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
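
---

Note (commentary, not part of the patch): the new build_aligned_inputs_and_targets
helper replaces the pad-and-shift logic that previously lived inside get_logits.
Below is a minimal sketch of its behavior, written as a free function (the patch
defines it as a method) with the default start_token=511 / stop_token=0 and an
illustrative token tensor:

    import torch
    import torch.nn.functional as F

    def build_aligned_inputs_and_targets(input, start_token, stop_token):
        # Prepend <start> to build the decoder inputs and append <stop> to
        # build the targets, so position i of the inputs predicts position i
        # of the targets (standard next-token teacher forcing).
        inp = F.pad(input, (1, 0), value=start_token)
        tar = F.pad(input, (0, 1), value=stop_token)
        return inp, tar

    tokens = torch.tensor([[5, 6, 7]])  # (batch=1, seq_len=3), values illustrative
    inp, tar = build_aligned_inputs_and_targets(tokens, 511, 0)
    print(inp)  # tensor([[511,   5,   6,   7]])
    print(tar)  # tensor([[  5,   6,   7,   0]])

Both outputs have length seq_len + 1, so the logits produced from the inputs
align one-to-one with the targets handed to F.cross_entropy, which expects
logits shaped (batch, classes, seq); hence the existing permute(0,2,1) in
get_logits.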