Work checkpoint for gpt asr hf2
This commit is contained in:
parent
cd89e6b42e
commit
5b5cbc057c
|
@ -1,3 +1,5 @@
|
|||
import functools
|
||||
import random
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
|
@ -203,13 +205,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
)
|
||||
|
||||
|
||||
class GptAsrHf2(nn.Module):
|
||||
NUMBER_SYMBOLS = len(symbols)
|
||||
START_TOKEN = NUMBER_SYMBOLS
|
||||
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+2
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True):
|
||||
|
||||
class GptAsrHf2(nn.Module):
|
||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
|
||||
number_text_tokens=512, start_token=511):
|
||||
super().__init__()
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_token = start_token
|
||||
|
||||
self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
|
||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||
|
||||
|
@ -219,33 +225,48 @@ class GptAsrHf2(nn.Module):
|
|||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
|
||||
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
|
||||
self.gpt_config = GPT2Config(vocab_size=self.NUMBER_TEXT_TOKENS,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
self.gpt = GPT2Model(self.gpt_config)
|
||||
# Override the built in positional embeddings
|
||||
del self.gpt.wpe
|
||||
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
|
||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||
|
||||
# Initialize the embeddings per the GPT-2 scheme
|
||||
for module in [self.text_pos_embedding, self.mel_pos_embedding]:
|
||||
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
def get_logits(self, mel_inputs, text_targets, get_attns=False):
|
||||
# Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
|
||||
text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
|
||||
text_inputs = F.pad(text_targets, (1,0), value=self.start_token)[:, :-1]
|
||||
text_emb = self.gpt.get_input_embeddings()(text_inputs)
|
||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||
mel_emb = self.mel_encoder(mel_inputs)
|
||||
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
if mel_inputs is None:
|
||||
emb = text_emb
|
||||
mel_len = 0
|
||||
else:
|
||||
mel_emb = self.mel_encoder(mel_inputs)
|
||||
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
mel_len = mel_emb.shape[1]
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
enc = gpt_out.last_hidden_state
|
||||
text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
|
||||
text_logits = self.final_norm(enc[:, mel_len:])
|
||||
text_logits = self.text_head(text_logits)
|
||||
text_logits = text_logits.permute(0,2,1)
|
||||
return text_logits
|
||||
|
@ -258,6 +279,12 @@ class GptAsrHf2(nn.Module):
|
|||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
return loss_text.mean(), text_logits
|
||||
|
||||
def text_only(self, text_targets):
|
||||
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
|
||||
text_logits = self.get_logits(None, text_targets)
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
return loss_text.mean(), text_logits
|
||||
|
||||
def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8):
|
||||
if not hasattr(self, 'inference_model'):
|
||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
||||
|
@ -271,14 +298,14 @@ class GptAsrHf2(nn.Module):
|
|||
# "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above.
|
||||
if cond_text is None:
|
||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||
fake_inputs[:,-1] = self.START_TOKEN
|
||||
fake_inputs[:,-1] = self.start_token
|
||||
else:
|
||||
cond_used = 10
|
||||
fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device)
|
||||
fake_inputs[:,-1-cond_used] = self.START_TOKEN
|
||||
fake_inputs[:,-1-cond_used] = self.start_token
|
||||
fake_inputs[:, -cond_used:] = cond_text[:, :cond_used]
|
||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_TOKEN, pad_token_id=0, eos_token_id=0,
|
||||
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.start_token, pad_token_id=0, eos_token_id=0,
|
||||
max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||
return gen[:, mel_emb.shape[1]:]
|
||||
|
||||
|
||||
|
@ -307,6 +334,7 @@ if __name__ == '__main__':
|
|||
|
||||
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
|
||||
l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
|
||||
gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
|
||||
|
||||
#start = time()
|
||||
#gpt.inference(torch.randn(1,80,350), num_beams=1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user