From b521d94b0158c859cec07948fcf0a2eccf8bc9bb Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 19 Aug 2021 16:33:41 -0600
Subject: [PATCH] Make gpt-asr more configurable

---
 codes/models/gpt_voice/gpt_asr.py           | 26 +++++++++++++-------------
 codes/scripts/audio/test_audio_segmentor.py |  2 +-
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py
index 6dc1311f..b0acd4f8 100644
--- a/codes/models/gpt_voice/gpt_asr.py
+++ b/codes/models/gpt_voice/gpt_asr.py
@@ -49,22 +49,22 @@ class MelEncoder(nn.Module):
 
 
 class GptAsr(nn.Module):
-    MAX_SYMBOLS_PER_PHRASE = 200
-    MAX_MEL_FRAMES = 1000 // 4
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
 
-    def __init__(self, layers=8, model_dim=512, heads=8):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_frames=1000):
         super().__init__()
+        self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
+        self.max_symbols_per_phrase = max_symbols_per_phrase
         self.model_dim = model_dim
-        self.max_mel_frames = self.MAX_MEL_FRAMES
+        self.max_mel_frames = self.max_mel_frames
         self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
         self.mel_encoder = MelEncoder(model_dim)
-        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE+1, model_dim)
-        self.mel_pos_embedding = nn.Embedding(self.MAX_MEL_FRAMES, model_dim)
-        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2+self.MAX_SYMBOLS_PER_PHRASE+self.MAX_MEL_FRAMES, heads=heads,
-                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.MAX_MEL_FRAMES)
+        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
+        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=2 + self.max_symbols_per_phrase + self.max_mel_frames, heads=heads,
+                               attn_dropout=.1, ff_dropout=.1, non_causal_sequence_partition=self.max_mel_frames)
 
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
@@ -72,11 +72,11 @@ class GptAsr(nn.Module):
     def forward(self, mel_inputs, text_targets):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
-        text_targets = F.pad(text_targets, (0, self.MAX_SYMBOLS_PER_PHRASE-text_targets.shape[1]))
+        text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
         text_emb = self.text_embedding(text_targets)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
         mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         emb = torch.cat([mel_emb, text_emb], dim=1)
@@ -84,7 +84,7 @@ class GptAsr(nn.Module):
         enc = self.gpt(emb)
 
         # Compute loss
-        text_logits = self.final_norm(enc[:, self.MAX_MEL_FRAMES:])
+        text_logits = self.final_norm(enc[:, self.max_mel_frames:])
         text_logits = self.text_head(text_logits)
         text_logits = text_logits.permute(0,2,1)
         loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
@@ -114,13 +114,13 @@ class GptAsr(nn.Module):
         b, _, s = mel_inputs.shape
         assert b == 1 # Beam search only works on batches of one.
         mel_emb = self.mel_encoder(mel_inputs)
-        mel_emb = F.pad(mel_emb, (0, self.MAX_MEL_FRAMES-mel_emb.shape[-1]))
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
 
         text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
         probabilities = torch.ones((b,), device=mel_emb.device)
-        while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE:
+        while text_seq.shape[-1] < self.max_symbols_per_phrase:
             text_emb = self.text_embedding(text_seq)
             text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
             if text_emb.shape[0] != mel_emb.shape[0]:
diff --git a/codes/scripts/audio/test_audio_segmentor.py b/codes/scripts/audio/test_audio_segmentor.py
index bfbdcee6..b18e1816 100644
--- a/codes/scripts/audio/test_audio_segmentor.py
+++ b/codes/scripts/audio/test_audio_segmentor.py
@@ -87,7 +87,7 @@ if __name__ == "__main__":
     sentence_number = 0
     last_detection_start = 0
     start = 0
-    clip_size = model.MAX_MEL_FRAMES
+    clip_size = model.max_mel_frames
     while start+clip_size < mels.shape[-1]:
         clip = mels[:, :, start:start+clip_size]
         preds = torch.nn.functional.sigmoid(model(clip)).squeeze(-1).squeeze(0) # Squeeze off the batch and sigmoid dimensions, leaving only the sequence dimension.
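
Usage note: the sketch below (not part of the commit) illustrates what the new
constructor arguments allow. The import path, the 80-channel mel shape, and
treating the returned value as the text loss are assumptions inferred from the
hunks above:

    import torch
    from models.gpt_voice.gpt_asr import GptAsr

    # Sequence limits are now constructor arguments instead of class constants.
    model = GptAsr(layers=8, model_dim=512, heads=8,
                   max_symbols_per_phrase=100,  # previously fixed at 200
                   max_mel_frames=2000)         # previously fixed at 1000
    # The model stores max_mel_frames // 4 positions, mirroring MelEncoder's
    # 4x temporal reduction: 2000 // 4 = 500 encoded mel frames.
    mel = torch.randn(1, 80, 2000)   # (batch, mel_channels, frames); 80 channels is an assumption
    text = torch.randint(0, model.NUMBER_SYMBOLS, (1, 50))
    loss = model(mel, text)          # cross-entropy over the text tokens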
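
The defaults are chosen so callers that relied on the old class constants see
identical behavior; a hypothetical sanity check, not part of the patch:

    m = GptAsr()
    assert m.max_symbols_per_phrase == 200  # old MAX_SYMBOLS_PER_PHRASE
    assert m.max_mel_frames == 250          # old MAX_MEL_FRAMES = 1000 // 4
    # External code only needs the attribute rename, as in
    # test_audio_segmentor.py above:
    #     clip_size = model.max_mel_frames  # formerly model.MAX_MEL_FRAMES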