diff --git a/codes/models/audio/tts/unified_voice3.py b/codes/models/audio/tts/unified_voice3.py index 42ba1251..dd26c789 100644 --- a/codes/models/audio/tts/unified_voice3.py +++ b/codes/models/audio/tts/unified_voice3.py @@ -238,7 +238,8 @@ class MelEncoder(nn.Module): class UnifiedVoice(nn.Module): def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192, - stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1): + stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1, + freeze_for_aligned_codes=False,): """ Args: layers: Number of layers in transformer stack. @@ -278,13 +279,21 @@ class UnifiedVoice(nn.Module): self.final_norm = nn.LayerNorm(model_dim) self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) self.mel_head = nn.Linear(model_dim, self.number_mel_codes) - self.aligned_head = nn.Linear(model_dim, self.number_aligned_text_codes) + self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes) # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding, self.mel_embedding] for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) + if freeze_for_aligned_codes: + for p in self.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + for p in self.aligned_head.parameters(): + del p.DO_NOT_TRAIN + p.requires_grad = True + def get_grad_norm_parameter_groups(self): return { 'conditioning_encoder': list(self.conditioning_encoder.parameters()), @@ -363,18 +372,18 @@ class UnifiedVoice(nn.Module): if types is not None: text_inputs = text_inputs * (1+types).unsqueeze(-1) - mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) - mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) conds = self.get_conditioning_latent(speech_conditioning_input) - ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1] + ac_expansion_factor = mel_codes.shape[-1] / aligned_codes.shape[-1] aligned_codes = aligned_codes.repeat(1, ac_expansion_factor) _, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0) + text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + mel_codes = self.set_mel_padding(mel_codes, wav_lengths) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) @@ -431,15 +440,16 @@ class UnifiedVoice(nn.Module): @register_model -def register_unified_voice2(opt_net, opt): +def register_unified_voice3(opt_net, opt): return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {})) if __name__ == '__main__': gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2) + mel = torch.randint(high=8192, size=(2,250)) + ac = torch.randint(high=256, size=(2,250*1024//443)) l = gpt(torch.randn(2, 3, 80, 800), torch.randint(high=256, size=(2,120)), torch.tensor([32, 120]), - torch.randint(high=8192, size=(2,250)), - torch.tensor([250*256,195*256]), + mel, torch.tensor([250*256,195*256]), ac, types=torch.tensor([0, 1]))