From 6996dfd9d5b590c6e5ebd62a80b16797a829d1ea Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 26 Dec 2021 15:17:24 -0700
Subject: [PATCH] asr_hf2: add independent position embedders

---
 codes/models/gpt_voice/gpt_asr_hf2.py | 9 +++++----
 codes/train.py                        | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 69df9b21..6b4a8e2b 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -223,6 +223,7 @@ class GptAsrHf2(nn.Module):
         self.max_mel_frames = self.max_mel_frames
         self.mel_encoder = MelEncoder(model_dim)
         self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
+        self.text_solo_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
         seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
         self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
@@ -248,11 +249,11 @@ class GptAsrHf2(nn.Module):
                 module.weight.data[module.padding_idx].zero_()
 
 
-    def get_logits(self, mel_inputs, text_targets, get_attns=False):
+    def get_logits(self, mel_inputs, text_targets, pos_emb, get_attns=False):
         # Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
         text_inputs = F.pad(text_targets, (1,0), value=self.start_token)[:, :-1]
         text_emb = self.gpt.get_input_embeddings()(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
+        text_emb = text_emb + pos_emb(torch.arange(text_emb.shape[1], device=text_inputs.device))
         if mel_inputs is None:
             emb = text_emb
             mel_len = 0
@@ -273,7 +274,7 @@ class GptAsrHf2(nn.Module):
 
     def forward(self, mel_inputs, text_targets, return_attentions=False):
         text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
-        text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
+        text_logits = self.get_logits(mel_inputs, text_targets, self.text_pos_embedding, get_attns=return_attentions)
         if return_attentions:
             return text_logits  # These weren't really the logits.
         loss_text = F.cross_entropy(text_logits, text_targets.long())
@@ -281,7 +282,7 @@ class GptAsrHf2(nn.Module):
 
     def text_only(self, text_targets):
         text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
-        text_logits = self.get_logits(None, text_targets)
+        text_logits = self.get_logits(None, text_targets, self.text_solo_pos_embedding)
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
diff --git a/codes/train.py b/codes/train.py
index ed4384cc..9475aa66 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -286,7 +286,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_unified_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
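
Note (illustrative, not part of the patch): the change gives the text-only training path its own positional embedding table (text_solo_pos_embedding) rather than sharing text_pos_embedding with the mel-conditioned path, and get_logits now receives the embedder to use as an argument. Below is a minimal standalone sketch of that pattern; ToyTextDecoder, its dimensions, and method names are hypothetical and only mirror the structure of GptAsrHf2.

# Sketch of the "independent position embedders" pattern introduced by the patch.
# Everything here is a simplified stand-in; it is not the repo's actual model.
import torch
import torch.nn as nn


class ToyTextDecoder(nn.Module):
    def __init__(self, vocab_size=256, max_text_len=100, model_dim=64):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, model_dim)
        # Two independent position embedders, mirroring text_pos_embedding /
        # text_solo_pos_embedding in the patch.
        self.text_pos_embedding = nn.Embedding(max_text_len + 1, model_dim)
        self.text_solo_pos_embedding = nn.Embedding(max_text_len + 1, model_dim)
        self.proj = nn.Linear(model_dim, vocab_size)

    def get_logits(self, text_inputs, pos_emb):
        # The position embedder is passed in by the caller, so the two forward
        # paths can use different tables while sharing all other weights.
        positions = torch.arange(text_inputs.shape[1], device=text_inputs.device)
        h = self.token_emb(text_inputs) + pos_emb(positions)
        return self.proj(h)

    def forward(self, text_inputs):
        # Conditioned path: uses the "paired" positional table.
        return self.get_logits(text_inputs, self.text_pos_embedding)

    def text_only(self, text_inputs):
        # Unconditioned (text-only) path: uses its own positional table.
        return self.get_logits(text_inputs, self.text_solo_pos_embedding)


if __name__ == "__main__":
    model = ToyTextDecoder()
    tokens = torch.randint(0, 256, (2, 20))
    print(model(tokens).shape)             # torch.Size([2, 20, 256])
    print(model.text_only(tokens).shape)   # torch.Size([2, 20, 256])

The design point this mirrors: when the same decoder is trained both with and without mel conditioning, positions in the text-only sequence start at offset 0 rather than after the mel prefix, so giving that path its own learned position table avoids forcing one embedding to serve both layouts.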