diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
index bf526b27..44f18b2c 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
@@ -118,6 +118,7 @@ class DiffusionTtsFlat(nn.Module):
             dropout=0,
             use_fp16=False,
             num_heads=16,
+            freeze_everything_except_autoregressive_inputs=False,
             # Parameters for regularization.
             layer_drop=.1,
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
@@ -151,7 +152,11 @@ class DiffusionTtsFlat(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
         self.code_norm = normalization(model_channels)
-        self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
+        self.autoregressive_latent_converter = nn.Sequential(nn.Conv1d(in_latent_channels, model_channels, 1),
+                                                             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+                                                             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+                                                             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+                                                             )
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                  AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
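
# A minimal shape sketch of the new converter above, with assumed channel counts.
# AttentionBlock (the repo's relative-position attention module) preserves the
# (batch, channels, time) shape, so nn.Identity stands in for it here; only the
# 1x1 Conv1d remaps channels.
import torch
from torch import nn
in_latent_channels, model_channels = 1024, 512  # assumed values for illustration
converter = nn.Sequential(nn.Conv1d(in_latent_channels, model_channels, 1),
                          nn.Identity(), nn.Identity(), nn.Identity())
x = torch.randn(2, in_latent_channels, 100)     # (batch, latent channels, time)
assert converter(x).shape == (2, model_channels, 100)
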
@@ -177,11 +182,19 @@ class DiffusionTtsFlat(nn.Module):
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
 
+        if freeze_everything_except_autoregressive_inputs:
+            for ap in self.autoregressive_latent_converter.parameters():
+                ap.ALLOWED_IN_FLAT = True
+            for p in self.parameters():
+                if not hasattr(p, 'ALLOWED_IN_FLAT'):
+                    p.requires_grad = False
+                    p.DO_NOT_TRAIN = True
+
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
             'layers': list(self.layers.parameters()),
-            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
+            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.autoregressive_latent_converter.parameters()),
             'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
             'time_embed': list(self.time_embed.parameters()),
         }
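
# A minimal, self-contained sketch of the freeze pattern above. ALLOWED_IN_FLAT
# and DO_NOT_TRAIN are plain attribute tags set on parameters; how the
# surrounding training harness consumes DO_NOT_TRAIN is assumed here, not shown
# in this diff.
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(4, 4)   # stands in for the rest of the model
        self.kept = nn.Conv1d(4, 4, 1)  # stands in for the latent converter
        # Tag the parameters that should keep training...
        for ap in self.kept.parameters():
            ap.ALLOWED_IN_FLAT = True
        # ...then freeze everything that is not tagged.
        for p in self.parameters():
            if not hasattr(p, 'ALLOWED_IN_FLAT'):
                p.requires_grad = False
                p.DO_NOT_TRAIN = True

toy = Toy()
# A downstream optimizer would then skip the tagged parameters, e.g.:
opt = torch.optim.AdamW(p for p in toy.parameters()
                        if not getattr(p, 'DO_NOT_TRAIN', False))
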
@@ -202,7 +215,7 @@ class DiffusionTtsFlat(nn.Module):
         cond_emb = conds.mean(dim=-1)
         cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
         if is_latent(aligned_conditioning):
-            code_emb = self.latent_converter(aligned_conditioning)
+            code_emb = self.autoregressive_latent_converter(aligned_conditioning)
         else:
             code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
             code_emb = self.code_converter(code_emb)
@@ -245,7 +258,7 @@ class DiffusionTtsFlat(nn.Module):
         if conditioning_free:
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
             unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
-            unused_params.extend(list(self.latent_converter.parameters()))
+            unused_params.extend(list(self.autoregressive_latent_converter.parameters()))
         else:
             if precomputed_aligned_embeddings is not None:
                 code_emb = precomputed_aligned_embeddings
@@ -256,7 +269,7 @@ class DiffusionTtsFlat(nn.Module):
             if is_latent(aligned_conditioning):
                 unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
             else:
-                unused_params.extend(list(self.latent_converter.parameters()))
+                unused_params.extend(list(self.autoregressive_latent_converter.parameters()))
 
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
@@ -297,7 +310,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8192,(2,100))
     cond = torch.randn(2, 100, 400)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
+    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5, freeze_everything_except_autoregressive_inputs=True)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
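
# A hedged sanity check for the new flag (not part of this commit): with the
# freeze enabled, only the autoregressive latent converter should remain
# trainable on the model constructed above.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
assert trainable and all(n.startswith('autoregressive_latent_converter') for n in trainable)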