Make gpt-asr more configurable
parent 570ed327ed
commit b521d94b01
@@ -49,22 +49,22 @@ class MelEncoder(nn.Module):
 
 
 class GptAsr(nn.Module):
-    MAX_SYMBOLS_PER_PHRASE = 200
-    MAX_MEL_FRAMES = 1000 // 4
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
 
-    def __init__(self, layers=8, model_dim=512, heads=8):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
         super().__init__()
+        self.max_mel_frames = max_mel_frames // 4  # Mel frames are reduced by a factor of 4 during encoding.
+        self.max_symbols_per_phrase = max_symbols_per_phrase
 
         self.model_dim = model_dim
-        self.max_mel_frames = self.MAX_MEL_FRAMES
+        self.max_mel_frames = self.max_mel_frames
         self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
         self.mel_encoder = MelEncoder(model_dim)
-        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE+1, model_dim)
-        self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
-        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
-                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
+        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
+        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads,
+                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames)
 
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
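A minimal sketch of how the new constructor arguments are meant to be used, assuming the class is importable (the import path below is a guess, not part of the diff):

import torch
from models.gpt_voice.gpt_asr import GptAsr  # hypothetical import path

# With no extra arguments the defaults reproduce the old hard-coded limits
# (200 symbols per phrase, 1000 raw mel frames).
model = GptAsr(layers=8, model_dim=512, heads=8)

# Larger limits for longer phrases/clips. Note that max_mel_frames is given in
# raw mel frames; the constructor divides it by 4 because MelEncoder reduces
# the frame count by a factor of 4.
big = GptAsr(layers=8, model_dim=512, heads=8,
             max_symbols_per_phrase=400, max_mel_frames=2000)
assert big.max_mel_frames == 2000 // 4
assert big.max_symbols_per_phrase == 400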
@@ -72,11 +72,11 @@ class GptAsr(nn.Module):
     def forward(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
-        text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
+        text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
         text_emb = self.text_embedding(text_targets)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
         mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
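The padding scheme in forward() can be checked in isolation. The sketch below uses made-up numbers (148 symbols, the default 200-token cap) only to show the resulting shapes; it is not code from the commit:

import torch
import torch.nn.functional as F

number_symbols = 148          # stand-in for len(symbols); the real value depends on the symbol set
max_symbols_per_phrase = 200  # the new constructor default

text_targets = torch.randint(0, number_symbols, (4, 57))          # a batch of 4 phrases, 57 tokens each
text_targets = F.pad(text_targets, (1, 0), value=number_symbols)  # prepend the "START" token (id == number_symbols)
text_targets = F.pad(text_targets, (0, max_symbols_per_phrase - text_targets.shape[1]))  # right-pad with zeros to the fixed length
print(text_targets.shape)  # torch.Size([4, 200])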
@@ -84,7 +84,7 @@ class GptAsr(nn.Module):
         enc = self.gpt(emb)
 
         # Compute loss
-        text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
+        text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
         loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
@@ -114,13 +114,13 @@ class GptAsr(nn.Module):
         b, _, s = mel_inputs.shape
         assert b == 1  # Beam search only works on batches of one.
         mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
 
         text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
         probabilities = torch.ones((b,), device=mel_emb.device)
-        while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE:
+        while text_seq.shape[-1] < self.max_symbols_per_phrase:
             text_emb = self.text_embedding(text_seq)
             text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
             if text_emb.shape[0] != mel_emb.shape[0]:
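Pieced together from the patterns visible in forward() and the beam-search hunk above, a simplified greedy decoder would look roughly like the following. This is an illustrative sketch only: the real method is a beam search, and its stop-token handling and probability bookkeeping are omitted here.

import torch
import torch.nn.functional as F

@torch.no_grad()
def greedy_decode(model, mel_inputs):
    # mel_inputs: (1, mel_channels, frames); batch size 1, mirroring the assert in the real method.
    assert mel_inputs.shape[0] == 1
    mel_emb = model.mel_encoder(mel_inputs)
    mel_emb = F.pad(mel_emb, (0, model.max_mel_frames - mel_emb.shape[-1]))
    mel_emb = mel_emb.permute(0, 2, 1).contiguous()
    mel_emb = mel_emb + model.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))

    # Start from the "START" token, then grow the text sequence one argmax token at a time.
    text_seq = torch.full((1, 1), fill_value=model.NUMBER_SYMBOLS, device=mel_emb.device)
    while text_seq.shape[-1] < model.max_symbols_per_phrase:
        text_emb = model.text_embedding(text_seq)
        text_emb = text_emb + model.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_seq.device))
        enc = model.gpt(torch.cat([mel_emb, text_emb], dim=1))
        logits = model.text_head(model.final_norm(enc[:, model.max_mel_frames:]))
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # no stop-token check; runs to the cap
        text_seq = torch.cat([text_seq, next_token], dim=1)
    return text_seq[:, 1:]  # drop the START token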
The same rename lands in a second file in the commit, inside its if __name__ == "__main__": block:

@@ -87,7 +87,7 @@ if __name__ == "__main__":
     sentence_number = 0
     last_detection_start = 0
     start = 0
-    clip_size = model.MAX_MEL_FRAMES
+    clip_size = model.max_mel_frames
     while start+clip_size < mels.shape[-1]:
         clip = mels[:, :, start:start+clip_size]
         preds = torch.nn.functional.sigmoid(model(clip)).squeeze(-1).squeeze(0)  # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
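The windowing loop in that script is easy to reproduce on its own. In the sketch below the real model is replaced with a stub detector and an 80-channel mel is assumed; neither is part of the commit, they only make the clip arithmetic runnable:

import torch

class StubDetector(torch.nn.Module):
    # Stand-in for the real model: emits one score per mel frame.
    max_mel_frames = 1000 // 4

    def forward(self, mel):
        return torch.zeros(mel.shape[0], mel.shape[-1], 1)

model = StubDetector()
mels = torch.randn(1, 80, 2200)  # (batch, mel_channels, frames)

start = 0
clip_size = model.max_mel_frames
while start + clip_size < mels.shape[-1]:
    clip = mels[:, :, start:start + clip_size]
    preds = torch.sigmoid(model(clip)).squeeze(-1).squeeze(0)  # (clip_size,) per-frame scores
    start += clip_size  # the hunk does not show how the real script advances; a fixed hop keeps the sketch simple
print(preds.shape)  # torch.Size([250])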