forked from mrq/DL-Art-School
Make gpt-asr more configurable
parent 570ed327ed
commit b521d94b01
@@ -49,22 +49,22 @@ class MelEncoder(nn.Module):
 class GptAsr(nn.Module):
-    MAX_SYMBOLS_PER_PHRASE = 200
-    MAX_MEL_FRAMES = 1000 // 4
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
 
-    def __init__(self, layers=8, model_dim=512, heads=8):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
         super().__init__()
+        self.max_mel_frames = max_mel_frames // 4  # Mel frames are reduced by a factor of 4 during encoding.
+        self.max_symbols_per_phrase = max_symbols_per_phrase
 
         self.model_dim = model_dim
-        self.max_mel_frames = self.MAX_MEL_FRAMES
+        self.max_mel_frames = self.max_mel_frames
         self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
         self.mel_encoder = MelEncoder(model_dim)
-        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE+1, model_dim)
-        self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
-        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
-                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
+        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
+        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads,
+                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames)
 
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
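The hunk above is the heart of the commit: the hard-coded MAX_SYMBOLS_PER_PHRASE and MAX_MEL_FRAMES class constants become max_symbols_per_phrase and max_mel_frames constructor arguments. A minimal sketch of how the reworked constructor might be called follows; the import path, the 80-channel mel layout, and the concrete values are assumptions for illustration, not part of this commit.

import torch
from models.gpt_voice.gpt_asr import GptAsr  # import path assumed from the repo layout

model = GptAsr(layers=8, model_dim=512, heads=8,
               max_symbols_per_phrase=120,   # previously fixed at 200
               max_mel_frames=600)           # raw mel frames; stored internally as 600 // 4 = 150

mel = torch.randn(1, 80, 480)            # (batch, mel_channels, frames); 80 channels assumed
text = torch.randint(0, 100, (1, 40))    # dummy symbol ids below NUMBER_SYMBOLS
out = model(mel, text)                   # forward pads both inputs up to the configured maxima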
@@ -72,11 +72,11 @@ class GptAsr(nn.Module):
     def forward(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
-        text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
+        text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
         text_emb = self.text_embedding(text_targets)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
         mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
@@ -84,7 +84,7 @@ class GptAsr(nn.Module):
         enc = self.gpt(emb)
 
         # Compute loss
-        text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
+        text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
         loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
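The loss hunk splits the transformer output at max_mel_frames so only the text half feeds the head, then shifts targets by one position so each step predicts the next symbol. A self-contained sketch of that pattern with illustrative shapes (250 mel positions, 200 text positions, 149 token classes; none of these values come from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

max_mel_frames, max_symbols, n_tokens, dim = 250, 200, 149, 512
enc = torch.randn(2, max_mel_frames + max_symbols, dim)       # stand-in for the transformer output over [mel | text]
text_targets = torch.randint(0, n_tokens, (2, max_symbols))   # stand-in for the START-padded targets

final_norm, text_head = nn.LayerNorm(dim), nn.Linear(dim, n_tokens)
text_logits = text_head(final_norm(enc[:, max_mel_frames:]))  # keep only the text half of the sequence
text_logits = text_logits.permute(0, 2, 1)                    # (batch, classes, seq), as cross_entropy expects
loss_text = F.cross_entropy(text_logits[:, :, :-1], text_targets[:, 1:].long())  # position t predicts token t+1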
@@ -114,13 +114,13 @@ class GptAsr(nn.Module):
         b, _, s = mel_inputs.shape
         assert b == 1  # Beam search only works on batches of one.
         mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
 
         text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
         probabilities = torch.ones((b,), device=mel_emb.device)
-        while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE:
+        while text_seq.shape[-1] < self.max_symbols_per_phrase:
             text_emb = self.text_embedding(text_seq)
             text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
             if text_emb.shape[0] != mel_emb.shape[0]:
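The inference hunk seeds decoding with the START token (NUMBER_SYMBOLS) and now bounds the loop by the configurable max_symbols_per_phrase instead of the old constant. The repo's routine is a beam search; purely as a rough illustration of that bounded autoregressive loop, here is a generic greedy skeleton with made-up token ids:

import torch

def greedy_decode(step_fn, start_token, stop_token, max_len):
    seq = torch.full((1, 1), start_token, dtype=torch.long)
    while seq.shape[-1] < max_len:                      # bounded like max_symbols_per_phrase
        next_token = step_fn(seq)                       # step_fn returns the next token id
        seq = torch.cat([seq, next_token.view(1, 1)], dim=-1)
        if next_token.item() == stop_token:
            break
    return seq

# Dummy step function that always emits token 0, so decoding stops after one step.
print(greedy_decode(lambda s: torch.tensor(0), start_token=148, stop_token=0, max_len=200))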
@@ -87,7 +87,7 @@ if __name__ == "__main__":
     sentence_number = 0
     last_detection_start = 0
     start = 0
-    clip_size = model.MAX_MEL_FRAMES
+    clip_size = model.max_mel_frames
     while start+clip_size < mels.shape[-1]:
         clip = mels[:, :, start:start+clip_size]
         preds = torch.nn.functional.sigmoid(model(clip)).squeeze(-1).squeeze(0)  # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
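The final hunk, from a separate inference script whose filename was not captured in this view, sizes its sliding window from the model's now-configurable max_mel_frames. A condensed, self-contained sketch of that windowing pattern, with illustrative names and a stand-in for the model call:

import torch

def iter_clips(mels, clip_size, hop):
    # Yield (start, clip) windows of width clip_size along the frame axis.
    start = 0
    while start + clip_size < mels.shape[-1]:
        yield start, mels[:, :, start:start + clip_size]
        start += hop

mels = torch.randn(1, 80, 4000)   # (batch, mel_channels, frames); 80 channels assumed
clip_size = 250                   # e.g. model.max_mel_frames under the old 1000 // 4 default
for start, clip in iter_clips(mels, clip_size, hop=clip_size):
    preds = torch.sigmoid(torch.randn(1, clip_size, 1)).squeeze(-1).squeeze(0)  # stand-in for model(clip)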