diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
index c56903c0..14f59f00 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
@@ -166,6 +166,7 @@ class DiffusionTtsFlat(nn.Module):
             DiffusionLayer(model_channels, dropout, num_heads),
         )
         self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
+        self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 
         self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
                                     [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
@@ -228,12 +229,14 @@ class DiffusionTtsFlat(nn.Module):
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
                                    code_emb)
+        expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
+        mel_pred = self.mel_head(expanded_code_emb)
 
         # Everything after this comment is timestep dependent.
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
+        code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb)
         x = self.inp_block(x)
-        x = torch.cat([x, F.interpolate(code_emb, size=x.shape[-1], mode='nearest')], dim=1)
+        x = torch.cat([x, code_emb], dim=1)
         x = self.integrating_conv(x)
         for i, lyr in enumerate(self.layers):
             # Do layer drop where applicable. Do not drop first and last layers.
@@ -253,7 +256,7 @@ class DiffusionTtsFlat(nn.Module):
             extraneous_addition = extraneous_addition + p.mean()
         out = out + extraneous_addition * 0
 
-        return out
+        return out, mel_pred
 
 
 @register_model
@@ -269,7 +272,7 @@ if __name__ == '__main__':
     ts = torch.LongTensor([600, 600])
     model = DiffusionTtsFlat(512, layer_drop=.3)
     # Test with latent aligned conditioning
-    o = model(clip, ts, aligned_latent, cond)
+    #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
     o = model(clip, ts, aligned_sequence, cond)
 
diff --git a/codes/train.py b/codes/train.py
index ee85b690..80e246dc 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -318,7 +318,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_gpt_tts_unified.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 95a9fa95..886c4d09 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -241,12 +241,13 @@ class ConfigurableStep(Module):
             # Finally, compute the losses.
             total_loss = 0
             for loss_name, loss in self.losses.items():
+                multiplier = 1
                 # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
                 # be very disruptive to a generator.
                 if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step'] or \
                    'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                    'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
-                    continue
+                    multiplier = 0  # Multiply by 0 so gradients still flow and DDP works. Effectively this means the loss is unused.
                 if loss.is_stateful():
                     l, lstate = loss(self.get_network_for_name(self.step_opt['training']), local_state)
                     local_state.update(lstate)
@@ -255,7 +256,7 @@ class ConfigurableStep(Module):
                     l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
                 if not l.isfinite():
                     print(f'!!Detected non-finite loss {loss_name}')
-                total_loss += l * self.weights[loss_name]
+                total_loss += l * self.weights[loss_name] * multiplier
                 # Record metrics.
                 if isinstance(l, torch.Tensor):
                     loss_accumulator.add_loss(loss_name, l)