diff --git a/codes/models/audio/music/tfdpc_v4.py b/codes/models/audio/music/tfdpc_v4.py
new file mode 100644
index 00000000..01dd376f
--- /dev/null
+++ b/codes/models/audio/music/tfdpc_v4.py
@@ -0,0 +1,352 @@
+import itertools
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+import torchvision
+
+from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
+from models.diffusion.unet_diffusion import TimestepBlock
+from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
+    FeedForward
+from trainer.networks import register_model
+from utils.util import checkpoint, print_network, load_audio
+
+
+class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
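+    """An nn.Sequential that routes the timestep embedding to TimestepBlock children and the rotary embedding to every child."""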
+    def forward(self, x, emb, rotary_emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb, rotary_emb)
+            else:
+                x = layer(x, rotary_emb)
+        return x
+
+
+class SubBlock(nn.Module):
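+    """Attention followed by a feed-forward (or conv) layer, with each sub-output concatenated onto the running input, dense-net style."""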
+    def __init__(self, inp_dim, contraction_dim, heads, dropout, use_conv):
+        super().__init__()
+        self.attn = Attention(inp_dim, out_dim=contraction_dim, heads=heads, dim_head=contraction_dim//heads, causal=False, dropout=dropout)
+        self.attnorm = nn.LayerNorm(contraction_dim)
+        self.use_conv = use_conv
+        if use_conv:
+            self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
+        else:
+            self.ff = FeedForward(inp_dim+contraction_dim, dim_out=contraction_dim, mult=2, dropout=dropout)
+        self.ffnorm = nn.LayerNorm(contraction_dim)
+
+    def forward(self, x, rotary_emb):
+        ah, _, _, _ = checkpoint(self.attn, x, None, None, None, None, None, rotary_emb)
+        ah = F.gelu(self.attnorm(ah))
+        h = torch.cat([ah, x], dim=-1)
+        hf = checkpoint(self.ff, h.permute(0,2,1) if self.use_conv else h)
+        hf = F.gelu(self.ffnorm(hf.permute(0,2,1) if self.use_conv else hf))
+        h = torch.cat([h, hf], dim=-1)
+        return h
+
+
+class ConcatAttentionBlock(TimestepBlock):
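+    """Residual trunk block: a timestep-conditioned RMS scale-shift prenorm, an optional projection of the conditioning
+    signal concatenated onto the trunk, two SubBlocks, and a zero-initialized output projection back to the trunk width."""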
+    def __init__(self, trunk_dim, contraction_dim, time_embed_dim, cond_dim_in, cond_dim_hidden, heads, dropout, cond_projection=True, use_conv=False):
+        super().__init__()
+        self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
+        if cond_projection:
+            self.tdim = trunk_dim+cond_dim_hidden
+            self.cond_project = nn.Linear(cond_dim_in, cond_dim_hidden)
+        else:
+            self.tdim = trunk_dim
+        self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
+        self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
+        self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
+        self.out.weight.data.zero_()
+
+    def forward(self, x, cond, timestep_emb, rotary_emb):
+        h = self.prenorm(x, norm_scale_shift_inp=timestep_emb)
+        if hasattr(self, 'cond_project'):
+            cond = self.cond_project(cond)
+            h = torch.cat([h, cond], dim=-1)
+        h = self.block1(h, rotary_emb)
+        h = self.block2(h, rotary_emb)
+        h = self.out(h[:,:,self.tdim:])
+        return h + x
+
+
+class ConditioningEncoder(nn.Module):
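+    """Summarizes the conditioning signal into a single vector: a 1x1 conv projection, a projected timestep token prepended
+    to the sequence, a transformer encoder, then a mean-pool over time."""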
+    def __init__(self,
+                 cond_dim,
+                 embedding_dim,
+                 time_embed_dim,
+                 attn_blocks=6,
+                 num_attn_heads=8,
+                 dropout=.1,
+                 do_checkpointing=False):
+        super().__init__()
+        self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
+        self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+        self.attn = Encoder(
+                dim=embedding_dim,
+                depth=attn_blocks,
+                heads=num_attn_heads,
+                ff_dropout=dropout,
+                attn_dropout=dropout,
+                use_rmsnorm=True,
+                ff_glu=True,
+                rotary_pos_emb=True,
+                zero_init_branch_output=True,
+                ff_mult=2,
+            )
+        self.dim = embedding_dim
+        self.do_checkpointing = do_checkpointing
+
+    def forward(self, x, time_emb):
+        h = self.init(x).permute(0,2,1)
+        time_enc = self.time_proj(time_emb)
+        h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
+        h = self.attn(h).permute(0,2,1)
+        return h.mean(dim=2).unsqueeze(1)
+
+
+class TransformerDiffusionWithPointConditioning(nn.Module):
+    """
+    A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
+    """
+    def __init__(
+            self,
+            in_channels=256,
+            out_channels=512,  # mean and variance
+            model_channels=1024,
+            contraction_dim=256,
+            time_embed_dim=256,
+            num_layers=8,
+            rotary_emb_dim=32,
+            input_cond_dim=1024,
+            num_heads=8,
+            dropout=0,
+            use_fp16=False,
+            # Parameters for regularization.
+            unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.time_embed_dim = time_embed_dim
+        self.out_channels = out_channels
+        self.dropout = dropout
+        self.unconditioned_percentage = unconditioned_percentage
+        self.enable_fp16 = use_fp16
+
+        self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim)
+
+        self.time_embed = nn.Sequential(
+            linear(time_embed_dim, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
+        self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
+        self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels,
+                                                                           contraction_dim,
+                                                                           time_embed_dim,
+                                                                           cond_dim_in=input_cond_dim,
+                                                                           cond_dim_hidden=input_cond_dim//2,
+                                                                           heads=num_heads,
+                                                                           dropout=dropout,
+                                                                           cond_projection=(k % 3 == 0),
+                                                                           use_conv=(k % 3 != 0),
+                                                                           ) for k in range(num_layers)])
+
+        self.out = nn.Sequential(
+            normalization(model_channels),
+            nn.SiLU(),
+            zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
+        )
+
+        self.debug_codes = {}
+
+    def get_grad_norm_parameter_groups(self):
+        attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
+        attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
+        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers]))
+        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers]))
+        blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
+        groups = {
+            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
+            'blk1_attention_layers': attn1,
+            'blk2_attention_layers': attn2,
+            'attention_layers': attn1 + attn2,
+            'blk1_ff_layers': ff1,
+            'blk2_ff_layers': ff2,
+            'ff_layers': ff1 + ff2,
+            'block_out_layers': blkout_layers,
+            'rotary_embeddings': list(self.rotary_embeddings.parameters()),
+            'out': list(self.out.parameters()),
+            'x_proj': list(self.inp_block.parameters()),
+            'layers': list(self.layers.parameters()),
+            'time_embed': list(self.time_embed.parameters()),
+            'conditioning_encoder': list(self.conditioning_encoder.parameters()),
+        }
+        return groups
+
+    def forward(self, x, timesteps, conditioning_input, conditioning_free=False):
+        unused_params = []
+
+        time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
+        cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
+
+        if conditioning_free:
+            # Repeat across the batch so the per-layer concatenation works for batch sizes > 1.
+            cond = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
+        else:
+            cond = cond_enc
+            # Randomly drop the conditioning for entire batch elements, analogous to classifier-free guidance training.
+            if self.training and self.unconditioned_percentage > 0:
+                unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
+                                                   device=cond.device) < self.unconditioned_percentage
+                cond = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(cond.shape[0], 1, 1), cond)
+            unused_params.append(self.unconditioned_embedding)
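+        # Broadcast the single conditioning vector across every sequence position so the blocks can concatenate it per-position.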
+        cond = cond.repeat(1,x.shape[-1],1)
+
+        with torch.autocast(x.device.type, enabled=self.enable_fp16):
+            x = self.inp_block(x).permute(0,2,1)
+
+            rotary_pos_emb = self.rotary_embeddings(x.shape[1]+1, x.device)
+            for layer in self.layers:
+                x = checkpoint(layer, x, cond, time_emb, rotary_pos_emb)
+
+        x = x.float().permute(0,2,1)
+        out = self.out(x)
+
+        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
+        extraneous_addition = 0
+        for p in unused_params:
+            extraneous_addition = extraneous_addition + p.mean()
+        out = out + extraneous_addition * 0
+
+        return out
+
+    def before_step(self, step):
+        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + \
+                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers]))
+        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get gradients roughly
+        # two orders of magnitude higher than the rest of the network. Ideally we would use parameter groups, but
+        # ZeroRedundancyOptimizer makes this trickier than directly fiddling with the gradients.
+        for p in scaled_grad_parameters:
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2
+
+
+@register_model
+def register_tfdpc4(opt_net, opt):
+    return TransformerDiffusionWithPointConditioning(**opt_net['kwargs'])
+
+
+def test_cheater_model():
+    clip = torch.randn(2, 256, 400)
+    cl = torch.randn(2, 256, 400)
+    ts = torch.LongTensor([600, 600])
+
+    # For music:
+    model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
+                                                        contraction_dim=384, num_heads=6, num_layers=18, dropout=0,
+                                                        unconditioned_percentage=.4)
+    print_network(model)
+    o = model(clip, ts, cl)
+    pg = model.get_grad_norm_parameter_groups()
+    def prmsz(lp):
+        sz = 0
+        for p in lp:
+            q = 1
+            for s in p.shape:
+                q *= s
+            sz += q
+        return sz
+    for k, v in pg.items():
+        print(f'{k}: {prmsz(v)/1000000}')
+
+
+def inference_tfdpc4_with_cheater():
+    with torch.no_grad():
+        os.makedirs('results/tfdpc_v3', exist_ok=True)
+
+        #length = 40 * 22050 // 256 // 16
+        samples = {'electronica1': load_audio('Y:\\split\\yt-music-eval\\00001.wav', 22050),
+                   'electronica2': load_audio('Y:\\split\\yt-music-eval\\00272.wav', 22050),
+                   'e_guitar': load_audio('Y:\\split\\yt-music-eval\\00227.wav', 22050),
+                   'creep': load_audio('Y:\\separated\\bt-music-3\\[2007] MTV Unplugged (Live) (Japan Edition)\\05 - Creep [Cover On Radiohead]\\00001\\no_vocals.wav', 22050),
+                   'rock1': load_audio('Y:\\separated\\bt-music-3\\2016 - Heal My Soul\\01 - Daze Of The Night\\00000\\no_vocals.wav', 22050),
+                   'kiss': load_audio('Y:\\separated\\bt-music-3\\KISS (2001) Box Set CD1\\02 Deuce (Demo Version)\\00000\\no_vocals.wav', 22050),
+                   'purp': load_audio('Y:\\separated\\bt-music-3\\Shades of Deep Purple\\11 Help (Alternate Take)\\00001\\no_vocals.wav', 22050),
+                   'western_stars': load_audio('Y:\\separated\\bt-music-3\\Western Stars\\01 Hitch Hikin\'\\00000\\no_vocals.wav', 22050),
+                   'silk': load_audio('Y:\\separated\\silk\\MonstercatSilkShowcase\\890\\00007\\no_vocals.wav', 22050),
+                   'long_electronica': load_audio('C:\\Users\\James\\Music\\longer_sample.wav', 22050),}
+        for k, sample in samples.items():
+            sample = sample.cuda()
+            length = sample.shape[0]//256//16
+
+            model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
+                                                              contraction_dim=512, num_heads=8, num_layers=12, dropout=0,
+                                                              use_fp16=False, unconditioned_percentage=0).eval().cuda()
+            model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth'))
+
+            from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+            spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True,
+                                                        'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda()
+            ref_mel = spec_fn({'in': sample.unsqueeze(0)})['out']
+            from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
+            cheater_encoder = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}).cuda()
+            ref_cheater = cheater_encoder({'in': ref_mel})['out']
+
+            from models.diffusion.respace import SpacedDiffusion
+            from models.diffusion.respace import space_timesteps
+            from models.diffusion.gaussian_diffusion import get_named_beta_schedule
+            diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [128]), model_mean_type='epsilon',
+                                   model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
+                                   conditioning_free=True, conditioning_free_k=1)
+
+            # Conventional decoding method:
+            gen_cheater = diffuser.ddim_sample_loop(model, (1,256,length), progress=True, model_kwargs={'conditioning_input': ref_cheater})
+
+            # Guidance decoding method:
+            #mask = torch.ones_like(ref_cheater)
+            #mask[:,:,15:-15] = 0
+            #gen_cheater = diffuser.p_sample_loop_with_guidance(model, ref_cheater, mask, model_kwargs={'conditioning_input': ref_cheater})
+
+            # Just decode the ref.
+            #gen_cheater = ref_cheater
+
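+            # Decode the generated cheater latent back into a mel spectrogram with a pretrained cheater-to-mel diffusion model.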
+            from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent
+            diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
+                                   model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
+                                   conditioning_free=True, conditioning_free_k=1)
+            wrap = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512, model_channels=1024,
+                                                         contraction_dim=512, prenet_channels=1024, input_vec_dim=256,
+                                                         prenet_layers=6, num_heads=8, num_layers=16, new_code_expansion=True,
+                                                         dropout=0, unconditioned_percentage=0).eval().cuda()
+            wrap.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd_cheater_from_scratch/models/56500_generator_ema.pth'))
+            cheater_to_mel = wrap.diff
+            gen_mel = diffuser.ddim_sample_loop(cheater_to_mel, (1,256,gen_cheater.shape[-1]*16), progress=True,
+                                             model_kwargs={'codes': gen_cheater.permute(0,2,1)})
+            torchvision.utils.save_image((gen_mel + 1)/2, f'results/tfdpc_v3/{k}.png')
+
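+            # Vocode the mel into a waveform with the mel2wav v3 diffusion model, then pixel-shuffle the 16 channels back to full rate.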
+            from utils.music_utils import get_mel2wav_v3_model
+            m2w = get_mel2wav_v3_model().cuda()
+            spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
+                                   model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
+                                   conditioning_free=True, conditioning_free_k=1)
+            from trainer.injectors.audio_injectors import denormalize_mel
+            gen_mel_denorm = denormalize_mel(gen_mel)
+            output_shape = (1,16,gen_mel_denorm.shape[-1]*256//16)
+            gen_wav = spectral_diffuser.ddim_sample_loop(m2w, output_shape, model_kwargs={'codes': gen_mel_denorm})
+            from trainer.injectors.audio_injectors import pixel_shuffle_1d
+            gen_wav = pixel_shuffle_1d(gen_wav, 16)
+
+            torchaudio.save(f'results/tfdpc_v3/{k}.wav', gen_wav.squeeze(1).cpu(), 22050)
+            torchaudio.save(f'results/tfdpc_v3/{k}_ref.wav', sample.unsqueeze(0).cpu(), 22050)
+
+if __name__ == '__main__':
+    test_cheater_model()
+    #inference_tfdpc4_with_cheater()
diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen3.py b/codes/models/audio/music/unet_diffusion_waveform_gen3.py
index 7a9101ef..e84ab75a 100644
--- a/codes/models/audio/music/unet_diffusion_waveform_gen3.py
+++ b/codes/models/audio/music/unet_diffusion_waveform_gen3.py
@@ -358,10 +358,13 @@ def register_unet_diffusion_waveform_gen3(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = torch.randn(2, 64, 880)
+    clip = torch.randn(2, 4, 880)
     aligned_sequence = torch.randn(2,256,220)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionWaveformGen()
+    model = DiffusionWaveformGen(in_channels=4, out_channels=8, model_channels=64, in_mel_channels=256,
+                                 channel_mult=[1,2,4,6,8,16], num_res_blocks=[2,2,2,1,1,0], mid_resnet_depth=24,
+                                 conditioning_dim_factor=8,
+                                 token_conditioning_resolutions=[4,16], dropout=.1, time_embed_dim_multiplier=4)
     # Test with sequence aligned conditioning
     o = model(clip, ts, aligned_sequence)
     print_network(model)
diff --git a/codes/utils/music_utils.py b/codes/utils/music_utils.py
index 79c2f8c9..c77dc4ca 100644
--- a/codes/utils/music_utils.py
+++ b/codes/utils/music_utils.py
@@ -10,6 +10,15 @@ def get_mel2wav_model():
     model.eval()
     return model
 
+def get_mel2wav_v3_model():
+    from models.audio.music.unet_diffusion_waveform_gen3 import DiffusionWaveformGen
+    model = DiffusionWaveformGen(model_channels=256, in_channels=16, in_mel_channels=256, out_channels=32, channel_mult=[1,1.5,2,4],
+                                 num_res_blocks=[2,1,1,0], mid_resnet_depth=24, token_conditioning_resolutions=[1,4],
+                                 dropout=0, time_embed_dim_multiplier=1, unconditioned_percentage=0)
+    model.load_state_dict(torch.load("../experiments/music_mel2wav_v3.pth", map_location=torch.device('cpu')))
+    model.eval()
+    return model
+
 def get_music_codegen():
     from models.audio.mel2vec import ContrastiveTrainingWrapper
     model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,