diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index fb33d5f6..dd21ea9c 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -250,7 +250,7 @@ class GptAsrHf2(nn.Module):
         # This model uses its own positional embeddings, which helps discriminate between text and audio MELs.
         self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
-        self.text_solo_embedding = nn.Parameter(torch.randn(1,1,512) * self.gpt.config.initializer_range, requires_grad=True)
+        self.text_solo_embedding = nn.Parameter(torch.randn(1,1,model_dim) * self.gpt.config.initializer_range, requires_grad=True)
 
         # Head layers
         self.final_norm = nn.LayerNorm(model_dim)
diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py
index 235a1bd8..5fe0ecbf 100644
--- a/codes/models/gpt_voice/unified_voice.py
+++ b/codes/models/gpt_voice/unified_voice.py
@@ -7,6 +7,7 @@ from transformers import GPT2Model, GPT2Config
 
 from models.arch_util import AttentionBlock
 from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
+from models.gpt_voice.gpt_asr_hf2 import ResBlock
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -34,6 +35,30 @@ class ConditioningEncoder(nn.Module):
         return h[:, :, 0]
 
 
+class MelEncoder(nn.Module):
+    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
+        super().__init__()
+        self.channels = channels
+        self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
+                                     nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
+                                     nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
+                                     nn.GroupNorm(channels//16, channels//2),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
+                                     nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
+                                     nn.GroupNorm(channels//8, channels),
+                                     nn.ReLU(),
+                                     nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
+                                     )
+        self.reduction = 4
+
+
+    def forward(self, x):
+        for e in self.encoder:
+            x = e(x)
+        return x.permute(0, 2, 1)
+
+
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
 
@@ -50,7 +75,7 @@ class UnifiedGptVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370,
                  max_conditioning_inputs=3, checkpointing=True, mel_length_compression=1024, max_conditioning_length=60,
                  number_text_tokens=256, start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
-                 stop_mel_token=8193, use_dedicated_position_embeddings_for_paired=True, shuffle_conditioning=True):
+                 stop_mel_token=8193, shuffle_conditioning=True, train_solo_embeddings=False, use_mel_codes_as_input=True):
         super().__init__()
 
         self.number_text_tokens = number_text_tokens
@@ -69,14 +94,8 @@ class UnifiedGptVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
-        self.text_pos_solo_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
-        if use_dedicated_position_embeddings_for_paired:
-            self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
-            self.text_pos_paired_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        else:
-            self.mel_pos_paired_embedding = self.mel_pos_solo_embedding
-            self.text_pos_paired_embedding = self.text_pos_solo_embedding
+        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
+        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
         self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
                                      n_positions=seq_length,
@@ -87,18 +106,26 @@ class UnifiedGptVoice(nn.Module):
                                      gradient_checkpointing=checkpointing,
                                      use_cache=not checkpointing)
         self.gpt = GPT2Model(self.gpt_config)
+        if train_solo_embeddings:
+            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
+            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
+        else:
+            self.mel_solo_embedding = 0
+            self.text_solo_embedding = 0
 
         # Override the built in positional embeddings
         del self.gpt.wpe
         self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
+        if not use_mel_codes_as_input:
+            self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
+
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
         self.max_conditioning_length = max_conditioning_length
 
         # Initialize the embeddings per the GPT-2 scheme
-        for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding,
-                       self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]:
+        for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]:
             module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
@@ -177,10 +204,10 @@ class UnifiedGptVoice(nn.Module):
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
 
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
         mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
         mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
-        mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
         if text_first:
             text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
         else:
@@ -204,7 +231,7 @@ class UnifiedGptVoice(nn.Module):
         speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
 
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_solo_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + self.text_solo_embedding
         text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean()
@@ -222,7 +249,7 @@ class UnifiedGptVoice(nn.Module):
 
         mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
         mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
-        mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
         mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
@@ -256,7 +283,7 @@ def register_unified_gpt_voice(opt_net, opt):
 
 
 if __name__ == '__main__':
-    gpt = UnifiedGptVoice(model_dim=256, heads=4, use_dedicated_position_embeddings_for_paired=False)
+    gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True)
     l = gpt(torch.randn(2, 80, 800),
             torch.randint(high=len(symbols), size=(2,80)),
             torch.randint(high=8192, size=(2,250)),
diff --git a/codes/train.py b/codes/train.py
index 9f7349ce..8fea3b20 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -286,7 +286,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mel_encoder_pred_codes.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
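
Reviewer note: a quick shape check of the new MelEncoder. This is a minimal sketch, not part of the patch; it assumes the codes/ directory is on PYTHONPATH and relies on the ResBlock imported from gpt_asr_hf2 in the hunk above.

import torch
from models.gpt_voice.unified_voice import MelEncoder

enc = MelEncoder(channels=512, resblocks_per_reduction=1)
mel = torch.randn(2, 80, 100)   # (batch, mel_channels, frames)
out = enc(mel)
# The two stride-2 convolutions halve the time axis twice (hence self.reduction = 4),
# and the final permute yields GPT-ready (batch, sequence, model_dim) activations.
assert out.shape == (2, 100 // 4, 512)

With use_mel_codes_as_input=False this module is installed as self.gpt.wte, so self.gpt.get_input_embeddings()(mel_inputs) consumes raw 80-channel mels rather than discrete mel codes.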
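
Likewise for train_solo_embeddings: with the flag off, the attribute is the plain integer 0, so the `+ self.text_solo_embedding` / `+ self.mel_solo_embedding` terms added in text_only/mel_only are no-ops; with it on, a (1, 1, model_dim) parameter broadcasts over the activations. A minimal sketch of the pattern (the 0.02 scale here stands in for the GPT-2 initializer_range):

import torch
import torch.nn as nn

model_dim = 512

# train_solo_embeddings=True: a learned offset that broadcasts over (batch, seq, model_dim).
text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02)
# train_solo_embeddings=False: plain 0, which makes the addition below a no-op.
# text_solo_embedding = 0

text_emb = torch.randn(2, 80, model_dim)  # stand-in for embedded text tokens
out = text_emb + text_solo_embedding      # the same expression covers both settings
assert out.shape == text_emb.shape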