diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py index bf526b27..44f18b2c 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py @@ -118,6 +118,7 @@ class DiffusionTtsFlat(nn.Module): dropout=0, use_fp16=False, num_heads=16, + freeze_everything_except_autoregressive_inputs=False, # Parameters for regularization. layer_drop=.1, unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. @@ -151,7 +152,11 @@ class DiffusionTtsFlat(nn.Module): AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), ) self.code_norm = normalization(model_channels) - self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1) + self.autoregressive_latent_converter = nn.Sequential(nn.Conv1d(in_latent_channels, model_channels, 1), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), @@ -177,11 +182,19 @@ class DiffusionTtsFlat(nn.Module): zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), ) + if freeze_everything_except_autoregressive_inputs: + for ap in list(self.autoregressive_latent_converter.parameters()): + ap.ALLOWED_IN_FLAT = True + for p in self.parameters(): + if not hasattr(p, 'ALLOWED_IN_FLAT'): + p.requires_grad = False + p.DO_NOT_TRAIN = True + def get_grad_norm_parameter_groups(self): groups = { 'minicoder': list(self.contextual_embedder.parameters()), 'layers': list(self.layers.parameters()), - 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()), + 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.autoregressive_latent_converter.parameters()) + list(self.autoregressive_latent_converter.parameters()), 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), 'time_embed': list(self.time_embed.parameters()), } @@ -202,7 +215,7 @@ class DiffusionTtsFlat(nn.Module): cond_emb = conds.mean(dim=-1) cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) if is_latent(aligned_conditioning): - code_emb = self.latent_converter(aligned_conditioning) + code_emb = self.autoregressive_latent_converter(aligned_conditioning) else: code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) code_emb = self.code_converter(code_emb) @@ -245,7 +258,7 @@ class DiffusionTtsFlat(nn.Module): if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) - unused_params.extend(list(self.latent_converter.parameters())) + unused_params.extend(list(self.autoregressive_latent_converter.parameters())) else: if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings @@ -256,7 +269,7 @@ class DiffusionTtsFlat(nn.Module): if is_latent(aligned_conditioning): unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: - unused_params.extend(list(self.latent_converter.parameters())) + unused_params.extend(list(self.autoregressive_latent_converter.parameters())) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) @@ -297,7 +310,7 @@ if __name__ == '__main__': aligned_sequence = torch.randint(0,8192,(2,100)) cond = torch.randn(2, 100, 400) ts = torch.LongTensor([600, 600]) - model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5) + model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5, freeze_everything_except_autoregressive_inputs=True) # Test with latent aligned conditioning #o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning