diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py
new file mode 100644
index 00000000..b2068d28
--- /dev/null
+++ b/codes/models/audio/music/transformer_diffusion9.py
@@ -0,0 +1,361 @@
+import itertools
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.arch_util import ResBlock
+from models.audio.music.music_quantizer2 import MusicQuantizer2
+from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
+from models.diffusion.unet_diffusion import TimestepBlock
+from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
+from trainer.networks import register_model
+from utils.util import checkpoint, print_network
+
+
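+# Distinguish continuous (float) latents from discrete (long) token sequences.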
+def is_latent(t):
+    return t.dtype == torch.float
+
+def is_sequence(t):
+    return t.dtype == torch.long
+
+
+class MultiGroupEmbedding(nn.Module):
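+    """One embedding table per token group: each group in the last dim of x is embedded separately and the results are concatenated along the feature dim."""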
+    def __init__(self, tokens, groups, dim):
+        super().__init__()
+        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+
+    def forward(self, x):
+        h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
+        return torch.cat(h, dim=-1)
+
+
+class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
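+    """A Sequential that passes the timestep embedding to TimestepBlock children and the rotary embedding to every child."""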
+    def forward(self, x, emb, rotary_emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb, rotary_emb)
+            else:
+                x = layer(x, rotary_emb)
+        return x
+
+
+class DietAttentionBlock(TimestepBlock):
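+    """A slimmed-down transformer block: project into a working dim, apply an RMS scale-shift norm driven by the
+    timestep embedding, then attention and a feedforward back to the input dim, finished with a residual connection."""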
+    def __init__(self, in_dim, dim, heads, dropout):
+        super().__init__()
+        self.proj = nn.Linear(in_dim, dim)
+        self.proj.bias.data.zero_()
+        self.rms_scale_norm = RMSScaleShiftNorm(dim, bias=False)
+        self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout)
+        self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
+
+    def forward(self, x, timestep_emb, rotary_emb):
+        h = self.proj(x)
+        h = self.rms_scale_norm(h, norm_scale_shift_inp=timestep_emb)
+        h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
+        h = checkpoint(self.ff, h)
+        return h + x
+
+
+class TransformerDiffusion(nn.Module):
+    """
+    A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
+    """
+    def __init__(
+            self,
+            prenet_channels=256,
+            prenet_layers=3,
+            model_channels=512,
+            block_channels=256,
+            num_layers=8,
+            in_channels=256,
+            rotary_emb_dim=32,
+            input_vec_dim=512,
+            out_channels=512,  # mean and variance
+            num_heads=16,
+            dropout=0,
+            use_fp16=False,
+            ar_prior=False,
+            # Parameters for regularization.
+            unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.prenet_channels = prenet_channels
+        self.out_channels = out_channels
+        self.dropout = dropout
+        self.unconditioned_percentage = unconditioned_percentage
+        self.enable_fp16 = use_fp16
+
+        self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
+
+        self.time_embed = nn.Sequential(
+            linear(prenet_channels, prenet_channels),
+            nn.SiLU(),
+            linear(prenet_channels, block_channels),
+        )
+
+        self.ar_prior = ar_prior
+        prenet_heads = prenet_channels//64
+        if ar_prior:
+            self.ar_input = nn.Linear(input_vec_dim, prenet_channels)
+            self.ar_prior_intg = Encoder(
+                    dim=prenet_channels,
+                    depth=prenet_layers,
+                    heads=prenet_heads,
+                    ff_dropout=dropout,
+                    attn_dropout=dropout,
+                    use_rmsnorm=True,
+                    ff_glu=True,
+                    rotary_pos_emb=True,
+                    zero_init_branch_output=True,
+                    ff_mult=1,
+                )
+        else:
+            self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
+            self.code_converter = Encoder(
+                        dim=prenet_channels,
+                        depth=prenet_layers,
+                        heads=prenet_heads,
+                        ff_dropout=dropout,
+                        attn_dropout=dropout,
+                        use_rmsnorm=True,
+                        ff_glu=True,
+                        rotary_pos_emb=True,
+                        zero_init_branch_output=True,
+                        ff_mult=1,
+                    )
+
+        self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
+        self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
+        self.intg = nn.Linear(prenet_channels*2, model_channels)
+        self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, num_heads, dropout) for _ in range(num_layers)])
+
+        self.out = nn.Sequential(
+            normalization(model_channels),
+            nn.SiLU(),
+            zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
+        )
+
+        self.debug_codes = {}
+
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()),
+            'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()),
+            'time_embed': list(self.time_embed.parameters()),
+        }
+        return groups
+
+    def timestep_independent(self, prior, expected_seq_len):
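+        # Build the conditioning embedding from the prior (AR latents or quantized codes) and stretch it to the
+        # diffusion sequence length. This does not depend on the timestep, so it can be precomputed for sampling
+        # and fed back in via `precomputed_code_embeddings`.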
+        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
+        code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
+
+        # Randomly mask out the conditioning branch for entire batch elements, implementing a mechanism similar to classifier-free guidance training.
+        if self.training and self.unconditioned_percentage > 0:
+            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
+                                               device=code_emb.device) < self.unconditioned_percentage
+            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
+                                   code_emb)
+
+        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
+        return expanded_code_emb
+
+    def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
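+        """
+        x: the signal being diffused, shaped [batch, in_channels, time].
+        timesteps: a 1-D batch of diffusion timesteps.
+        codes: the conditioning prior (quantized latents or AR latents), projected and integrated before use.
+        precomputed_code_embeddings: optionally, the output of timestep_independent(), supplied in place of codes.
+        conditioning_free: if True, replaces the conditioning with the learned unconditioned embedding.
+        """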
+        if precomputed_code_embeddings is not None:
+            assert codes is None and conditioning_input is None, "Provide either precomputed code embeddings or codes/conditioning inputs, not both."
+
+        unused_params = []
+        if conditioning_free:
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1)
+        else:
+            if precomputed_code_embeddings is not None:
+                code_emb = precomputed_code_embeddings
+            else:
+                code_emb = self.timestep_independent(codes, x.shape[-1])
+            unused_params.append(self.unconditioned_embedding)
+
+        with torch.autocast(x.device.type, enabled=self.enable_fp16):
+            blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels))
+            x = self.inp_block(x).permute(0,2,1)
+
+            rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
+            x = self.intg(torch.cat([x, code_emb], dim=-1))
+            for layer in self.layers:
+                x = checkpoint(layer, x, blk_emb, rotary_pos_emb)
+
+        x = x.float().permute(0,2,1)
+        out = self.out(x)
+
+        # Involve potentially unused parameters in the loss (with zero weight) so DDP does not complain about them.
+        extraneous_addition = 0
+        for p in unused_params:
+            extraneous_addition = extraneous_addition + p.mean()
+        out = out + extraneous_addition * 0
+
+        return out
+
+
+class TransformerDiffusionWithQuantizer(nn.Module):
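+    """Couples TransformerDiffusion with a MusicQuantizer2 that turns the ground-truth mel into the latent prior
+    conditioning the diffusion model. The quantizer receives no gradients for the first `freeze_quantizer_until` steps."""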
+    def __init__(self, quantizer_dims=[1024], freeze_quantizer_until=20000, **kwargs):
+        super().__init__()
+
+        self.internal_step = 0
+        self.freeze_quantizer_until = freeze_quantizer_until
+        self.diff = TransformerDiffusion(**kwargs)
+        self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims,
+                                         codevector_dim=quantizer_dims[0], codebook_size=256,
+                                         codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
+        self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature
+        del self.quantizer.up
+
+    def update_for_step(self, step, *args):
+        self.internal_step = step
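+        # Decay the Gumbel softmax temperature toward its minimum; the decay schedule only begins once the
+        # quantizer is unfrozen.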
+        qstep = max(0, self.internal_step - self.freeze_quantizer_until)
+        self.quantizer.quantizer.temperature = max(
+            self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep,
+            self.quantizer.min_gumbel_temperature,
+        )
+
+    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
+        quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
+        with torch.set_grad_enabled(quant_grad_enabled):
+            proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
+            proj = proj.permute(0,2,1)
+
+        # Touch the quantizer parameters with zero weight so DDP does not complain about unused parameters.
+        if not quant_grad_enabled:
+            unused = 0
+            for p in self.quantizer.parameters():
+                unused = unused + p.mean() * 0
+            proj = proj + unused
+            diversity_loss = diversity_loss * 0
+
+        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        if disable_diversity:
+            return diff
+        return diff, diversity_loss
+
+    def get_debug_values(self, step, __):
+        if self.quantizer.total_codes > 0:
+            return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes],
+                    'gumbel_temperature': self.quantizer.quantizer.temperature}
+        else:
+            return {}
+
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
+            'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
+            'quantizer_encoder': list(self.quantizer.encoder.parameters()),
+            'quant_codebook': [self.quantizer.quantizer.codevectors],
+            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
+            'out': list(self.diff.out.parameters()),
+            'x_proj': list(self.diff.inp_block.parameters()),
+            'layers': list(self.diff.layers.parameters()),
+            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
+            'time_embed': list(self.diff.time_embed.parameters()),
+        }
+        return groups
+
+
+class TransformerDiffusionWithARPrior(nn.Module):
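+    """Couples TransformerDiffusion with a frozen GptMusicLower autoregressive model whose latents serve as the
+    conditioning prior for the diffusion network."""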
+    def __init__(self, freeze_diff=False, **kwargs):
+        super().__init__()
+
+        self.internal_step = 0
+        from models.audio.music.gpt_music import GptMusicLower
+        self.ar = GptMusicLower(dim=512, layers=12)
+        for p in self.ar.parameters():
+            p.DO_NOT_TRAIN = True
+            p.requires_grad = False
+
+        self.diff = TransformerDiffusion(ar_prior=True, **kwargs)
+        if freeze_diff:
+            for p in self.diff.parameters():
+                p.DO_NOT_TRAIN = True
+                p.requires_grad = False
+            for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
+                del p.DO_NOT_TRAIN
+                p.requires_grad = True
+
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
+            'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
+            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
+            'out': list(self.diff.out.parameters()),
+            'x_proj': list(self.diff.inp_block.parameters()),
+            'layers': list(self.diff.layers.parameters()),
+            'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()),
+            'time_embed': list(self.diff.time_embed.parameters()),
+        }
+        return groups
+
+    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
+        with torch.no_grad():
+            prior = self.ar(truth_mel, conditioning_input, return_latent=True)
+
+        diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
+        return diff
+
+
+@register_model
+def register_transformer_diffusion9(opt_net, opt):
+    return TransformerDiffusion(**opt_net['kwargs'])
+
+
+@register_model
+def register_transformer_diffusion9_with_quantizer(opt_net, opt):
+    return TransformerDiffusionWithQuantizer(**opt_net['kwargs'])
+
+
+@register_model
+def register_transformer_diffusion9_with_ar_prior(opt_net, opt):
+    return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
+
+
+def test_quant_model():
+    clip = torch.randn(2, 256, 400)
+    cond = torch.randn(2, 256, 400)
+    ts = torch.LongTensor([600, 600])
+    model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536,
+                                              prenet_channels=1024, num_heads=12,
+                                              input_vec_dim=1024, num_layers=24, prenet_layers=6)
+    model.get_grad_norm_parameter_groups()
+
+    quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
+    model.quantizer.load_state_dict(quant_weights, strict=False)
+
+    torch.save(model.state_dict(), 'sample.pth')
+    print_network(model)
+    o = model(clip, ts, clip, cond)
+
+
+def test_ar_model():
+    clip = torch.randn(2, 256, 400)
+    cond = torch.randn(2, 256, 400)
+    ts = torch.LongTensor([600, 600])
+    model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536,
+                                            input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True,
+                                            unconditioned_percentage=.4)
+    model.get_grad_norm_parameter_groups()
+
+    ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
+    model.ar.load_state_dict(ar_weights, strict=True)
+    diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth')
+    pruned_diff_weights = {}
+    for k,v in diff_weights.items():
+        if k.startswith('diff.'):
+            pruned_diff_weights[k.replace('diff.', '')] = v
+    model.diff.load_state_dict(pruned_diff_weights, strict=False)
+    torch.save(model.state_dict(), 'sample.pth')
+
+    model(clip, ts, cond, conditioning_input=cond)
+
+
+if __name__ == '__main__':
+    test_quant_model()
diff --git a/codes/models/audio/tts/unet_diffusion_tts10.py b/codes/models/audio/tts/unet_diffusion_tts10.py
deleted file mode 100644
index 1bbc4f4b..00000000
--- a/codes/models/audio/tts/unet_diffusion_tts10.py
+++ /dev/null
@@ -1,330 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import autocast
-
-from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
-    Downsample, Upsample, TimestepBlock
-from models.lucidrains.x_transformers import Encoder
-from scripts.audio.gen.use_diffuse_tts import ceil_multiple
-from trainer.networks import register_model
-from utils.util import checkpoint
-
-
-class ResBlock(TimestepBlock):
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        dims=2,
-        kernel_size=3,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        padding = 1 if kernel_size == 3 else 2
-
-        self.in_layers = nn.Sequential(
-            normalization(channels),
-            nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 1, padding=0),
-        )
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
-            nn.SiLU(),
-            nn.Dropout(p=dropout),
-            zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
-            ),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        return checkpoint(
-            self._forward, x, emb
-        )
-
-    def _forward(self, x, emb):
-        h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-        h = h + emb_out
-        h = self.out_layers(h)
-        return self.skip_connection(x) + h
-
-
-class DiffusionTts(nn.Module):
-    def __init__(
-            self,
-            model_channels,
-            in_channels=100,
-            num_tokens=256,
-            out_channels=200,  # mean and variance
-            dropout=0,
-            # m                 1,  2,   4,   8
-            block_channels=  (512,640, 768,1024),
-            num_res_blocks=  (3,    3,   3,   3),
-            token_conditioning_resolutions=(2,4,8),
-            attention_resolutions=(2,4,8),
-            conv_resample=True,
-            dims=1,
-            use_fp16=False,
-            kernel_size=3,
-            scale_factor=2,
-            num_heads=None,
-            time_embed_dim_multiplier=4,
-            nil_guidance_fwd_proportion=.15,
-    ):
-        super().__init__()
-
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        self.out_channels = out_channels
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.conv_resample = conv_resample
-        self.dtype = torch.float16 if use_fp16 else torch.float32
-        self.dims = dims
-        self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
-        self.mask_token_id = num_tokens
-        num_heads = model_channels // 64 if num_heads is None else num_heads
-
-        padding = 1 if kernel_size == 3 else 2
-
-        time_embed_dim = model_channels * time_embed_dim_multiplier
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.code_embedding = nn.Embedding(num_tokens+1, model_channels)
-        self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels // 2, 3, padding=1, stride=2),
-                                                   nn.Conv1d(model_channels//2, model_channels,3,padding=1,stride=2))
-        self.conditioning_encoder = Encoder(
-                    dim=model_channels,
-                    depth=4,
-                    heads=num_heads,
-                    ff_dropout=dropout,
-                    attn_dropout=dropout,
-                    use_rmsnorm=True,
-                    ff_glu=True,
-                    rotary_pos_emb=True,
-                )
-
-        self.codes_encoder = Encoder(
-                    dim=model_channels,
-                    depth=8,
-                    heads=num_heads,
-                    ff_dropout=dropout,
-                    attn_dropout=dropout,
-                    use_rms_scaleshift_norm=True,
-                    ff_glu=True,
-                    rotary_pos_emb=True,
-                    zero_init_branch_output=True,
-                )
-
-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
-                )
-            ]
-        )
-        token_conditioning_blocks = []
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-
-        for level, (blk_chan, num_blocks) in enumerate(zip(block_channels, num_res_blocks)):
-            if ds in token_conditioning_resolutions:
-                token_conditioning_block = nn.Conv1d(model_channels, ch, 1)
-                token_conditioning_block.weight.data *= .02
-                self.input_blocks.append(token_conditioning_block)
-                token_conditioning_blocks.append(token_conditioning_block)
-
-            for _ in range(num_blocks):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=blk_chan,
-                        dims=dims,
-                        kernel_size=kernel_size,
-                    )
-                ]
-                ch = blk_chan
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            num_heads=num_heads,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(block_channels) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=1, pad=0
-                        )
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-            ),
-            AttentionBlock(
-                ch,
-                num_heads=num_heads,
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-            ),
-        )
-        self._feature_size += ch
-
-        self.output_blocks = nn.ModuleList([])
-        for level, (blk_chan, num_blocks) in list(enumerate(zip(block_channels, num_res_blocks)))[::-1]:
-            for i in range(num_blocks + 1):
-                ich = input_block_chans.pop()
-                layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=blk_chan,
-                        dims=dims,
-                        kernel_size=kernel_size,
-                    )
-                ]
-                ch = blk_chan
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                        )
-                    )
-                if level and i == num_blocks:
-                    out_ch = ch
-                    layers.append(
-                        Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
-                    )
-                    ds //= 2
-                self.output_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-
-        self.out = nn.Sequential(
-            normalization(ch),
-            nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
-        )
-
-    def forward(self, x, timesteps, codes, conditioning_input=None):
-        """
-        Apply the model to an input batch.
-
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :param codes: an aligned text input.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        with autocast(x.device.type):
-            orig_x_shape = x.shape[-1]
-            cm = ceil_multiple(x.shape[-1], 16)
-            if cm != 0:
-                pc = (cm-x.shape[-1])/x.shape[-1]
-                x = F.pad(x, (0,cm-x.shape[-1]))
-                codes = F.pad(codes, (0, int(pc * codes.shape[-1])))
-
-            hs = []
-            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
-            # Mask out guidance tokens for un-guided diffusion.
-            if self.training and self.nil_guidance_fwd_proportion > 0:
-                token_mask = torch.rand(codes.shape, device=codes.device) < self.nil_guidance_fwd_proportion
-                codes = torch.where(token_mask, self.mask_token_id, codes)
-            code_emb = self.code_embedding(codes).permute(0, 2, 1)
-            cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
-            cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
-            code_emb = self.codes_encoder(code_emb.permute(0,2,1), norm_scale_shift_inp=cond_emb).permute(0,2,1)
-
-            first = True
-            time_emb = time_emb.float()
-            h = x
-            for k, module in enumerate(self.input_blocks):
-                if isinstance(module, nn.Conv1d):
-                    h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
-                    h = h + h_tok
-                else:
-                    with autocast(x.device.type, enabled=not first):
-                        # First block has autocast disabled to allow a high precision signal to be properly vectorized.
-                        h = module(h, time_emb)
-                    hs.append(h)
-                first = False
-            h = self.middle_block(h, time_emb)
-            for module in self.output_blocks:
-                h = torch.cat([h, hs.pop()], dim=1)
-                h = module(h, time_emb)
-
-        # Last block also has autocast disabled for high-precision outputs.
-        h = h.float()
-        out = self.out(h)
-        return out[:, :, :orig_x_shape]
-
-
-@register_model
-def register_diffusion_tts10(opt_net, opt):
-    return DiffusionTts(**opt_net['kwargs'])
-
-
-if __name__ == '__main__':
-    clip = torch.randn(2, 100, 500).cuda()
-    tok = torch.randint(0,256, (2,230)).cuda()
-    cond = torch.randn(2, 100, 300).cuda()
-    ts = torch.LongTensor([600, 600]).cuda()
-    model = DiffusionTts(512).cuda()
-    print(sum(p.numel() for p in model.parameters()) / 1000000)
-    model(clip, ts, tok, cond)
-
diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py
index b48eb51e..158d358c 100644
--- a/codes/models/lucidrains/x_transformers.py
+++ b/codes/models/lucidrains/x_transformers.py
@@ -352,12 +352,12 @@ class RMSNorm(nn.Module):
 
 
 class RMSScaleShiftNorm(nn.Module):
-    def __init__(self, dim, eps=1e-8):
+    def __init__(self, dim, eps=1e-8, bias=True):
         super().__init__()
         self.scale = dim ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(dim))
-        self.scale_shift_process = nn.Linear(dim, dim * 2)
+        self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias)
 
     def forward(self, x, norm_scale_shift_inp):
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale