Make gpt-asr more configurable

This commit is contained in:
James Betker 2021-08-19 16:33:41 -06:00
parent 570ed327ed
commit b521d94b01
2 changed files with 14 additions and 14 deletions

View File

@ -49,22 +49,22 @@ class MelEncoder(nn.Module):
class GptAsr(nn.Module):
MAX_SYMBOLS_PER_PHRASE = 200
MAX_MEL_FRAMES = 1000 // 4
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
def __init__(self, layers=8, model_dim=512, heads=8):
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
super().__init__()
self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.max_mel_frames = self.MAX_MEL_FRAMES
self.max_mel_frames = self.max_mel_frames
self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
self.mel_encoder = MelEncoder(model_dim)
self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE+1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads,
attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
@ -72,11 +72,11 @@ class GptAsr(nn.Module):
def forward(self, mel_inputs, text_targets):
# Pad front and back. Pad at front is the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
text_emb = self.text_embedding(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1)
@ -84,7 +84,7 @@ class GptAsr(nn.Module):
enc = self.gpt(emb)
# Compute loss
text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
text_logits = self.final_norm(enc[:, self.max_mel_frames:])
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
@ -114,13 +114,13 @@ class GptAsr(nn.Module):
b, _, s = mel_inputs.shape
assert b == 1 # Beam search only works on batches of one.
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
probabilities = torch.ones((b,), device=mel_emb.device)
while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE:
while text_seq.shape[-1] < self.max_symbols_per_phrase:
text_emb = self.text_embedding(text_seq)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
if text_emb.shape[0] != mel_emb.shape[0]:

View File

@ -87,7 +87,7 @@ if __name__ == "__main__":
sentence_number = 0
last_detection_start = 0
start = 0
clip_size = model.MAX_MEL_FRAMES
clip_size = model.max_mel_frames
while start+clip_size < mels.shape[-1]:
clip = mels[:, :, start:start+clip_size]
preds = torch.nn.functional.sigmoid(model(clip)).squeeze(-1).squeeze(0) # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.