From 5b5cbc057cf759981ee14d334b1c1c233800a6ae Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 26 Dec 2021 10:29:12 -0700
Subject: [PATCH] Work checkpoint for gpt asr hf2

---
 codes/models/gpt_voice/gpt_asr_hf2.py | 76 ++++++++++++++++++---------
 1 file changed, 52 insertions(+), 24 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 8ffdf5de..69df9b21 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -1,3 +1,5 @@
+import functools
+import random
 from time import time
 
 import torch
@@ -203,13 +205,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         )
 
 
-class GptAsrHf2(nn.Module):
-    NUMBER_SYMBOLS = len(symbols)
-    START_TOKEN = NUMBER_SYMBOLS
-    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+2
+def null_position_embeddings(range, dim):
+    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
-    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True):
+
+class GptAsrHf2(nn.Module):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
+                 number_text_tokens=512, start_token=511):
         super().__init__()
+        self.number_text_tokens = number_text_tokens
+        self.start_token = start_token
+
         self.max_mel_frames = max_mel_frames // 4  # Mel frames are reduced by a factor of 4 during encoding.
         self.max_symbols_per_phrase = max_symbols_per_phrase
 
@@ -219,33 +225,48 @@ class GptAsrHf2(nn.Module):
         self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
         seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
-        self.gpt_config = GPT2Config(vocab_size=self.NUMBER_TEXT_TOKENS,
-                                     n_positions=seq_length,
-                                     n_ctx=seq_length,
-                                     n_embd=model_dim,
-                                     n_layer=layers,
-                                     n_head=heads,
-                                     gradient_checkpointing=checkpointing,
-                                     use_cache=not checkpointing)
+        self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
+                                     n_positions=seq_length,
+                                     n_ctx=seq_length,
+                                     n_embd=model_dim,
+                                     n_layer=layers,
+                                     n_head=heads,
+                                     gradient_checkpointing=checkpointing,
+                                     use_cache=not checkpointing)
         self.gpt = GPT2Model(self.gpt_config)
+        # Override the built in positional embeddings
+        del self.gpt.wpe
+        self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+
         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+
+        # Initialize the embeddings per the GPT-2 scheme
+        for module in [self.text_pos_embedding, self.mel_pos_embedding]:
+            module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
+        text_inputs = F.pad(text_targets, (1,0), value=self.start_token)[:, :-1]
         text_emb = self.gpt.get_input_embeddings()(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
-        mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = mel_emb.permute(0,2,1).contiguous()
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
-        emb = torch.cat([mel_emb, text_emb], dim=1)
+        if mel_inputs is None:
+            emb = text_emb
+            mel_len = 0
+        else:
+            mel_emb = self.mel_encoder(mel_inputs)
+            mel_emb = mel_emb.permute(0,2,1).contiguous()
+            mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+            emb = torch.cat([mel_emb, text_emb], dim=1)
+            mel_len = mel_emb.shape[1]
         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
         if get_attns:
             return gpt_out.attentions
         enc = gpt_out.last_hidden_state
-        text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
+        text_logits = self.final_norm(enc[:, mel_len:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
         return text_logits
 
@@ -258,6 +279,12 @@ class GptAsrHf2(nn.Module):
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
+    def text_only(self, text_targets):
+        text_targets = F.pad(text_targets, (0,1))  # Pad the targets with a <0> so that all have a "stop" token.
+        text_logits = self.get_logits(None, text_targets)
+        loss_text = F.cross_entropy(text_logits, text_targets.long())
+        return loss_text.mean(), text_logits
+
     def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8):
         if not hasattr(self, 'inference_model'):
             self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
@@ -271,14 +298,14 @@ class GptAsrHf2(nn.Module):
         # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
         if cond_text is None:
             fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
-            fake_inputs[:,-1] = self.START_TOKEN
+            fake_inputs[:,-1] = self.start_token
         else:
             cond_used = 10
             fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
-            fake_inputs[:,-1-cond_used] = self.START_TOKEN
+            fake_inputs[:,-1-cond_used] = self.start_token
             fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
-        gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_TOKEN, pad_token_id=0, eos_token_id=0,
-                                            max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
+        gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
+                                            max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
         return gen[:, mel_emb.shape[1]:]
 
@@ -307,6 +334,7 @@ if __name__ == '__main__':
     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
     l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
+    gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
     #start = time()
     #gpt.inference(torch.randn(1,80,350), num_beams=1)
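Note on the main technique in this patch: GPT2Model.forward looks up self.wpe(position_ids) and adds the result to inputs_embeds, so replacing wpe with a callable that returns zeros disables GPT-2's built-in learned positional embeddings. The patch then applies its own per-modality position tables (text_pos_embedding / mel_pos_embedding) before concatenating the mel and text segments, presumably so each modality's positions count from zero. Below is a minimal standalone sketch of that trick; the config sizes, tensor shapes, and every name other than null_position_embeddings are illustrative assumptions, not taken from this repository.

    # Minimal sketch (illustrative, not part of the patch): disable GPT-2's
    # built-in positional embeddings by swapping `wpe` for a zero-returning callable.
    import functools

    import torch
    import torch.nn as nn
    from transformers import GPT2Config, GPT2Model


    def null_position_embeddings(range, dim):
        # Same output shape an nn.Embedding lookup would produce, but all zeros.
        return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


    dim = 64  # assumed toy size
    gpt = GPT2Model(GPT2Config(vocab_size=128, n_positions=256, n_embd=dim, n_layer=2, n_head=2))
    del gpt.wpe  # drop the learned table and its parameters
    gpt.wpe = functools.partial(null_position_embeddings, dim=dim)

    # Apply a custom positional embedding outside the model instead,
    # mirroring what GptAsrHf2 does for its text/mel segments.
    pos_emb = nn.Embedding(256, dim)
    x = torch.randn(2, 10, dim)  # stand-in for concatenated modality embeddings
    x = x + pos_emb(torch.arange(10))
    out = gpt(inputs_embeds=x, return_dict=True)
    print(out.last_hidden_state.shape)  # torch.Size([2, 10, 64])

Because the partial replaces wpe as a plain attribute rather than an nn.Module, it contributes no parameters and is invisible to the optimizer, which is presumably why the patch del-s the original embedding before reassigning.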