diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 0884c80c..461f03b8 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -45,6 +45,32 @@ class MelEncoder(nn.Module):
                                      nn.ReLU(),
                                      nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
                                      )
+        self.reduction = 4
+
+    def forward(self, x):
+        for e in self.encoder:
+            x = e(x)
+        return x
+
+
+class LeanMelEncoder(nn.Module):
+    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=1):
+        super().__init__()
+        self.channels = channels
+        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//2, kernel_size=5, stride=2, padding=1),
+                                     nn.GroupNorm(channels//16, channels//2),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
+                                     nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
+                                     nn.GroupNorm(channels//8, channels),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
+                                     nn.Conv1d(channels, channels, kernel_size=3, stride=2, padding=1),
+                                     nn.GroupNorm(channels//8, channels),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
+                                     )
+        self.reduction = 8
 
     def forward(self, x):
         for e in self.encoder:
@@ -211,21 +237,18 @@ def null_position_embeddings(range, dim):
 
 class GptAsrHf2(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True,
-                 number_text_tokens=512, start_token=511, stop_token=0, mel_encoder_resblocks_per_level=2):
+                 number_text_tokens=512, start_token=511, stop_token=0, lean_encoder=False):
         super().__init__()
         self.number_text_tokens = number_text_tokens
         self.start_token = start_token
         self.stop_token = stop_token
-
-        self.max_mel_frames = max_mel_frames // 4 # Mel frames are reduced by a factor of 4 during encoding.
         self.max_symbols_per_phrase = max_symbols_per_phrase
-
         self.model_dim = model_dim
-        self.max_mel_frames = self.max_mel_frames
-        self.mel_encoder = MelEncoder(model_dim, resblocks_per_reduction=mel_encoder_resblocks_per_level)
-        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
+        if lean_encoder:
+            self.mel_encoder = LeanMelEncoder(model_dim)
+        else:
+            self.mel_encoder = MelEncoder(model_dim, resblocks_per_reduction=1)
+        self.max_mel_frames = max_mel_frames // self.mel_encoder.reduction
         seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
         self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
                                      n_positions=seq_length,
@@ -236,12 +259,15 @@ class GptAsrHf2(nn.Module):
                                      gradient_checkpointing=checkpointing,
                                      use_cache=not checkpointing)
         self.gpt = GPT2Model(self.gpt_config)
-        self.text_solo_embedding = nn.Parameter(torch.randn(1,1,512) * self.gpt.config.initializer_range, requires_grad=True)
-
         # Override the built in positional embeddings
         del self.gpt.wpe
         self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
 
+        # This model uses its own positional embeddings, which helps discriminate between text and audio MELs.
+        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
+        self.text_solo_embedding = nn.Parameter(torch.randn(1,1,512) * self.gpt.config.initializer_range, requires_grad=True)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
 
@@ -336,7 +362,7 @@ def distill():
 
 
 if __name__ == '__main__':
     #distill()
-    gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
+    gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8, lean_encoder=True)
     l = gpt(torch.randn(2,80,640), torch.randint(high=len(symbols), size=(2,80)))
     gpt.text_only(torch.randint(high=len(symbols), size=(2,120)))
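
Not part of the patch, but a quick sanity check of the downsampling arithmetic behind the new `reduction` attribute. The sketch below reproduces only the three stride-2 convolutions of LeanMelEncoder (the GroupNorm/ReLU/ResBlock stages are assumed length-preserving, as same-padding residual blocks typically are), with channel sizes assuming the default model_dim=512; the derived sizes are illustrative, not copied from the repo.

import torch
import torch.nn as nn

# Stand-in for LeanMelEncoder's three stride-2 stages only (norms/activations/ResBlocks omitted).
downsample = nn.Sequential(
    nn.Conv1d(80, 256, kernel_size=5, stride=2, padding=1),
    nn.Conv1d(256, 512, kernel_size=3, stride=2, padding=1),
    nn.Conv1d(512, 512, kernel_size=3, stride=2, padding=1),
)

mel = torch.randn(2, 80, 1400)   # (batch, mel_channels, frames), matching max_mel_frames=1400
out = downsample(mel)
print(out.shape)                 # torch.Size([2, 512, 175]) -> 1400 // 8 encoded positions

# Why GptAsrHf2 now derives its sequence budget from the encoder instead of a hard-coded // 4:
reduction = 8                           # LeanMelEncoder.reduction (MelEncoder reports 4)
max_mel_frames = 1400 // reduction      # 175
seq_length = 2 + 250 + max_mel_frames   # 2 + max_symbols_per_phrase + max_mel_frames = 427
print(seq_length)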