diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index f0d0f6b7..435526c6 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -241,7 +241,7 @@ class UnifiedVoice(nn.Module): mel_length_compression=1024, number_text_tokens=256, start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, - checkpointing=True, average_conditioning_embeddings=False): + checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False): """ Args: layers: Number of layers in transformer stack. @@ -272,10 +272,10 @@ class UnifiedVoice(nn.Module): self.stop_mel_token = stop_mel_token self.layers = layers self.heads = heads + self.max_conditioning_inputs = max_conditioning_inputs self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2 self.model_dim = model_dim - self.max_conditioning_inputs = max_conditioning_inputs self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.average_conditioning_embeddings = average_conditioning_embeddings @@ -304,6 +304,15 @@ class UnifiedVoice(nn.Module): for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) + if freeze_everything_but_position_embeddings: + for p in self.parameters(): + p.requires_grad = False + p.DO_NOT_TRAIN = True + for m in [self.mel_pos_embedding, self.text_pos_embedding]: + for p in m.parameters(): + del p.DO_NOT_TRAIN + p.requires_grad = True + def get_grad_norm_parameter_groups(self): return { 'conditioning_encoder': list(self.conditioning_encoder.parameters()), @@ -566,7 +575,7 @@ def register_unified_voice2(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4) + gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True) l = gpt(torch.randn(2, 3, 80, 800), torch.randint(high=256, size=(2,120)), torch.tensor([32, 120]),